forked from tangger/lerobot
Fix tests
This commit is contained in:
348
tests/fixtures/dataset_factories.py
vendored
348
tests/fixtures/dataset_factories.py
vendored
@@ -11,16 +11,19 @@ import torch
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.defaults import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_KEYS,
|
||||
DUMMY_KEYS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
DUMMY_VIDEO_INFO,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,16 +76,33 @@ def img_factory(img_array_factory):
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory():
|
||||
def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||
}
|
||||
else:
|
||||
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
|
||||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
**DEFAULT_FEATURES,
|
||||
}
|
||||
|
||||
return _create_features
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory(features_factory):
|
||||
def _create_info(
|
||||
codebase_version: str = CODEBASE_VERSION,
|
||||
fps: int = DEFAULT_FPS,
|
||||
robot_type: str = DUMMY_ROBOT_TYPE,
|
||||
keys: list[str] = DUMMY_KEYS,
|
||||
image_keys: list[str] | None = None,
|
||||
video_keys: list[str] = DUMMY_CAMERA_KEYS,
|
||||
shapes: dict | None = None,
|
||||
names: dict | None = None,
|
||||
total_episodes: int = 0,
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
@@ -90,30 +110,14 @@ def info_factory():
|
||||
total_chunks: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
videos_path: str = DEFAULT_VIDEO_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if not image_keys:
|
||||
image_keys = []
|
||||
if not shapes:
|
||||
shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
|
||||
if not names:
|
||||
names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys}
|
||||
|
||||
video_info = {"videos_path": videos_path}
|
||||
for key in video_keys:
|
||||
video_info[key] = {
|
||||
"video.fps": fps,
|
||||
"video.width": shapes[key]["width"],
|
||||
"video.height": shapes[key]["height"],
|
||||
"video.channels": shapes[key]["channels"],
|
||||
"video.codec": "av1",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"data_path": data_path,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": total_frames,
|
||||
@@ -123,12 +127,9 @@ def info_factory():
|
||||
"chunks_size": chunks_size,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"keys": keys,
|
||||
"video_keys": video_keys,
|
||||
"image_keys": image_keys,
|
||||
"shapes": shapes,
|
||||
"names": names,
|
||||
"videos": video_info if len(video_keys) > 0 else None,
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
return _create_info
|
||||
@@ -137,32 +138,26 @@ def info_factory():
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_factory():
|
||||
def _create_stats(
|
||||
keys: list[str] = DUMMY_KEYS,
|
||||
image_keys: list[str] | None = None,
|
||||
video_keys: list[str] = DUMMY_CAMERA_KEYS,
|
||||
shapes: dict | None = None,
|
||||
features: dict[str] | None = None,
|
||||
) -> dict:
|
||||
if not image_keys:
|
||||
image_keys = []
|
||||
if not shapes:
|
||||
shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
|
||||
stats = {}
|
||||
for key in keys:
|
||||
shape = shapes[key]
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full(shape, 0, dtype=np.float32).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
for key in [*image_keys, *video_keys]:
|
||||
shape = (3, 1, 1)
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full(shape, 0, dtype=np.float32).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
dtype = ft["dtype"]
|
||||
if dtype in ["image", "video"]:
|
||||
stats[key] = {
|
||||
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
else:
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
|
||||
"min": np.full(shape, 0, dtype=dtype).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
|
||||
}
|
||||
return stats
|
||||
|
||||
return _create_stats
|
||||
@@ -185,7 +180,7 @@ def episodes_factory(tasks_factory):
|
||||
def _create_episodes(
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
task_dicts: dict | None = None,
|
||||
tasks: dict | None = None,
|
||||
multi_task: bool = False,
|
||||
):
|
||||
if total_episodes <= 0 or total_frames <= 0:
|
||||
@@ -193,18 +188,18 @@ def episodes_factory(tasks_factory):
|
||||
if total_frames < total_episodes:
|
||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||
|
||||
if not task_dicts:
|
||||
if not tasks:
|
||||
min_tasks = 2 if multi_task else 1
|
||||
total_tasks = random.randint(min_tasks, total_episodes)
|
||||
task_dicts = tasks_factory(total_tasks)
|
||||
tasks = tasks_factory(total_tasks)
|
||||
|
||||
if total_episodes < len(task_dicts) and not multi_task:
|
||||
if total_episodes < len(tasks) and not multi_task:
|
||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||
|
||||
# Generate random lengths that sum up to total_length
|
||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||
|
||||
tasks_list = [task_dict["task"] for task_dict in task_dicts]
|
||||
tasks_list = [task_dict["task"] for task_dict in tasks]
|
||||
num_tasks_available = len(tasks_list)
|
||||
|
||||
episodes_list = []
|
||||
@@ -231,81 +226,56 @@ def episodes_factory(tasks_factory):
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(img_array_factory, episodes, tasks):
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
episode_dicts: list[dict] = episodes,
|
||||
task_dicts: list[dict] = tasks,
|
||||
keys: list[str] = DUMMY_KEYS,
|
||||
image_keys: list[str] | None = None,
|
||||
shapes: dict | None = None,
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
fps: int = DEFAULT_FPS,
|
||||
) -> datasets.Dataset:
|
||||
if not image_keys:
|
||||
image_keys = []
|
||||
if not shapes:
|
||||
shapes = make_dummy_shapes(keys=keys, camera_keys=image_keys)
|
||||
key_features = {
|
||||
key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
|
||||
for key in keys
|
||||
}
|
||||
image_features = {key: datasets.Image() for key in image_keys} if image_keys else {}
|
||||
common_features = {
|
||||
"episode_index": datasets.Value(dtype="int64"),
|
||||
"frame_index": datasets.Value(dtype="int64"),
|
||||
"timestamp": datasets.Value(dtype="float32"),
|
||||
"next.done": datasets.Value(dtype="bool"),
|
||||
"index": datasets.Value(dtype="int64"),
|
||||
"task_index": datasets.Value(dtype="int64"),
|
||||
}
|
||||
features = datasets.Features(
|
||||
{
|
||||
**key_features,
|
||||
**image_features,
|
||||
**common_features,
|
||||
}
|
||||
)
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
if not features:
|
||||
features = features_factory()
|
||||
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
next_done_col = np.array([], dtype=bool)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
|
||||
for ep_dict in episode_dicts:
|
||||
for ep_dict in episodes:
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
episode_index_col = np.concatenate(
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
)
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
next_done_ep = np.full(ep_dict["length"], False, dtype=bool)
|
||||
next_done_ep[-1] = True
|
||||
next_done_col = np.concatenate((next_done_col, next_done_ep))
|
||||
ep_task_index = get_task_index(task_dicts, ep_dict["tasks"][0])
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
|
||||
index_col = np.arange(len(episode_index_col))
|
||||
key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys}
|
||||
|
||||
image_cols = {}
|
||||
if image_keys:
|
||||
for key in image_keys:
|
||||
image_cols[key] = [
|
||||
img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"])
|
||||
robot_cols = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(width=ft["shapes"][0], height=ft["shapes"][1])
|
||||
for _ in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||
|
||||
hf_features = get_hf_features_from_features(features)
|
||||
dataset = datasets.Dataset.from_dict(
|
||||
{
|
||||
**key_cols,
|
||||
**image_cols,
|
||||
"episode_index": episode_index_col,
|
||||
"frame_index": frame_index_col,
|
||||
**robot_cols,
|
||||
"timestamp": timestamp_col,
|
||||
"next.done": next_done_col,
|
||||
"frame_index": frame_index_col,
|
||||
"episode_index": episode_index_col,
|
||||
"index": index_col,
|
||||
"task_index": task_index,
|
||||
},
|
||||
features=features,
|
||||
features=hf_features,
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
return dataset
|
||||
@@ -315,26 +285,37 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info,
|
||||
stats,
|
||||
tasks,
|
||||
episodes,
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_metadata(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
**kwargs,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
local_files_only: bool = False,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
@@ -347,48 +328,68 @@ def lerobot_dataset_metadata_factory(
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(
|
||||
repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
|
||||
)
|
||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info,
|
||||
stats,
|
||||
tasks,
|
||||
episodes,
|
||||
hf_dataset,
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
hf_ds: datasets.Dataset = hf_dataset,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episode_dicts: list[dict] | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episode_dicts:
|
||||
episode_dicts = episodes_factory(
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
hf_ds=hf_ds,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
hf_dataset=hf_dataset,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
**kwargs,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
@@ -402,44 +403,3 @@ def lerobot_dataset_factory(
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
return _create_lerobot_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_from_episodes_factory(
|
||||
info_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
lerobot_dataset_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_total_episodes(
|
||||
root: Path,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
**kwargs,
|
||||
):
|
||||
info_dict = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
task_dicts = tasks_factory(total_tasks)
|
||||
episode_dicts = episodes_factory(
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
task_dicts=task_dicts,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
|
||||
return lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
hf_ds=hf_dataset,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _create_lerobot_dataset_total_episodes
|
||||
|
||||
Reference in New Issue
Block a user