Dataset v2.0 (#461)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Committed by Simon Alibert via GitHub on 2024-11-29 19:04:00 +01:00
parent 96c7052777
commit 32eb0cec8f
71 changed files with 6115 additions and 2235 deletions

tests/fixtures/constants.py (new file)

@@ -0,0 +1,29 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {
"video.fps": DEFAULT_FPS,
"video.codec": "av1",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"has_audio": False,
}
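
A quick check of how these constants compose, mirroring what features_factory in dataset_factories.py below does; a minimal sketch whose assertions are illustrative only:

from tests.fixtures.constants import DUMMY_CAMERA_FEATURES, DUMMY_MOTOR_FEATURES, DUMMY_VIDEO_INFO

# Merge the video metadata into each camera feature, as features_factory does.
camera_ft = {key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in DUMMY_CAMERA_FEATURES.items()}
assert camera_ft["laptop"]["video.codec"] == "av1"
assert set(DUMMY_MOTOR_FEATURES) == {"action", "state"}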

tests/fixtures/dataset_factories.py (new file)

@@ -0,0 +1,396 @@
import random
from pathlib import Path
from unittest.mock import patch
import datasets
import numpy as np
import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE,
DUMMY_VIDEO_INFO,
)
def get_task_index(task_dicts: list[dict], task: str) -> int:
    # Map each task description back to its task_index.
    task_to_task_index = {d["task"]: d["task_index"] for d in task_dicts}
    return task_to_task_index[task]
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
else:
raise ValueError(dtype)
return img_array
return _create_img_array
@pytest.fixture(scope="session")
def img_factory(img_array_factory):
def _create_img(height=100, width=100) -> PIL.Image.Image:
img_array = img_array_factory(height=height, width=width)
return PIL.Image.fromarray(img_array)
return _create_img
@pytest.fixture(scope="session")
def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
return {
**motor_features,
**camera_ft,
**DEFAULT_FEATURES,
}
return _create_features
@pytest.fixture(scope="session")
def info_factory(features_factory):
def _create_info(
codebase_version: str = CODEBASE_VERSION,
fps: int = DEFAULT_FPS,
robot_type: str = DUMMY_ROBOT_TYPE,
total_episodes: int = 0,
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
features = features_factory(motor_features, camera_features, use_videos)
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": chunks_size,
"fps": fps,
"splits": {},
"data_path": data_path,
"video_path": video_path if use_videos else None,
"features": features,
}
return _create_info
@pytest.fixture(scope="session")
def stats_factory():
    def _create_stats(
        features: dict[str, dict] | None = None,
    ) -> dict:
        stats = {}
        # Tolerate being called without features so dependent fixtures can use defaults.
        for key, ft in (features or {}).items():
shape = ft["shape"]
dtype = ft["dtype"]
if dtype in ["image", "video"]:
stats[key] = {
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
}
else:
stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(),
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
"min": np.full(shape, 0, dtype=dtype).tolist(),
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
}
return stats
return _create_stats
@pytest.fixture(scope="session")
def tasks_factory():
    def _create_tasks(total_tasks: int = 3) -> list[dict]:
tasks_list = []
for i in range(total_tasks):
task_dict = {"task_index": i, "task": f"Perform action {i}."}
tasks_list.append(task_dict)
return tasks_list
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def _create_episodes(
total_episodes: int = 3,
total_frames: int = 400,
        tasks: list[dict] | None = None,
        multi_task: bool = False,
    ) -> list[dict]:
        if total_episodes <= 0 or total_frames <= 0:
            raise ValueError("total_episodes and total_frames must be positive integers.")
        if total_frames < total_episodes:
            raise ValueError("total_frames must be greater than or equal to total_episodes.")
if not tasks:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
        if total_episodes < len(tasks) and not multi_task:
            raise ValueError("The number of tasks should not exceed the number of episodes.")
        # Generate random lengths that sum up to total_frames
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks]
num_tasks_available = len(tasks_list)
episodes_list = []
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
episodes_list.append(
{
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
)
return episodes_list
return _create_episodes
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
tasks = tasks_factory()
if not episodes:
episodes = episodes_factory()
if not features:
features = features_factory()
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
index_col = np.arange(len(episode_index_col))
robot_cols = {}
for key, ft in features.items():
if ft["dtype"] == "image":
                robot_cols[key] = [
                    # Camera shape is (height, width, channels), per DUMMY_CAMERA_FEATURES.
                    img_array_factory(height=ft["shape"][0], width=ft["shape"][1])
                    for _ in range(len(index_col))
                ]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
{
**robot_cols,
"timestamp": timestamp_col,
"frame_index": frame_index_col,
"episode_index": episode_index_col,
"index": index_col,
"task_index": task_index,
},
features=hf_features,
)
dataset.set_transform(hf_transform_to_torch)
return dataset
return _create_hf_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset_metadata(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
local_files_only: bool = False,
) -> LeRobotDatasetMetadata:
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
tasks=tasks,
episodes=episodes,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
total_episodes: int = 3,
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
episode_dicts = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
return _create_lerobot_dataset
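
A minimal usage sketch for this factory: a hypothetical test relying on pytest's built-in tmp_path fixture, and assuming LeRobotDataset exposes its metadata as dataset.meta:

def test_factory_builds_mocked_dataset(tmp_path, lerobot_dataset_factory):
    # The factory writes metadata and parquet files under tmp_path and returns a
    # dataset backed entirely by mocks, so no hub access happens.
    dataset = lerobot_dataset_factory(root=tmp_path, total_episodes=2, total_frames=50)
    assert dataset.meta.total_episodes == 2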

tests/fixtures/files.py (new file)

@@ -0,0 +1,114 @@
import json
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
@pytest.fixture(scope="session")
def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
if not info:
info = info_factory()
fpath = dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
return _create_info_json_file
@pytest.fixture(scope="session")
def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
if not stats:
stats = stats_factory()
fpath = dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
return _create_stats_json_file
@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
if not tasks:
tasks = tasks_factory()
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks)
return fpath
return _create_tasks_jsonl_file
@pytest.fixture(scope="session")
def episode_path(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
if not episodes:
episodes = episodes_factory()
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes)
return fpath
return _create_episodes_jsonl_file
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return fpath
return _create_single_episode_parquet
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
total_episodes = info["total_episodes"]
for ep_idx in range(total_episodes):
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return dir / "data"
return _create_multi_episode_parquet
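
A hypothetical test exercising the single-episode helper, again assuming pytest's tmp_path:

import pyarrow.parquet as pq

def test_single_episode_parquet(tmp_path, single_episode_parquet_path):
    # Writes episode 0 into its chunked path, e.g. data/chunk-000/episode_000000.parquet.
    fpath = single_episode_parquet_path(tmp_path, ep_idx=0)
    table = pq.read_table(fpath)
    # Every row in the written file belongs to the requested episode.
    assert set(table["episode_index"].to_pylist()) == {0}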

tests/fixtures/hub.py (new file)

@@ -0,0 +1,105 @@
from pathlib import Path
import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info_factory,
info_path,
stats_factory,
stats_path,
tasks_factory,
tasks_path,
episodes_factory,
episode_path,
single_episode_parquet_path,
hf_dataset_factory,
):
"""
This factory allows to patch snapshot_download such that when called, it will create expected files rather
than making calls to the hub api. Its design allows to pass explicitly files which you want to be created.
"""
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
        def _extract_episode_index_from_path(fpath: str) -> int | None:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
return episode_index
else:
return None
def _mock_snapshot_download(
repo_id: str,
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
# Create allowed files
for rel_path in allowed_files:
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes)
return str(local_dir)
return _mock_snapshot_download
return _mock_snapshot_download_func
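
Finally, a sketch of calling the mocked snapshot_download directly rather than through a patch; the test name is illustrative and tmp_path is pytest's built-in fixture:

from pathlib import Path

from lerobot.common.datasets.utils import INFO_PATH
from tests.fixtures.constants import DUMMY_REPO_ID

def test_mock_snapshot_download(tmp_path, mock_snapshot_download_factory):
    mock_snapshot_download = mock_snapshot_download_factory()
    # Instead of hitting the hub, the call materializes metadata and parquet files locally.
    local_dir = mock_snapshot_download(DUMMY_REPO_ID, local_dir=tmp_path)
    assert (Path(local_dir) / INFO_PATH).exists()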