Fix tests

This commit is contained in:
Simon Alibert
2024-11-05 19:09:12 +01:00
parent aed9f4036a
commit f3630ad910
13 changed files with 437 additions and 496 deletions

View File

@@ -1,67 +0,0 @@
import datasets
import pytest
from lerobot.common.datasets.utils import get_episode_data_index
from tests.fixtures.defaults import DUMMY_CAMERA_KEYS
@pytest.fixture(scope="session")
def empty_info(info_factory) -> dict:
return info_factory(
keys=[],
image_keys=[],
video_keys=[],
shapes={},
names={},
)
@pytest.fixture(scope="session")
def info(info_factory) -> dict:
return info_factory(
total_episodes=4,
total_frames=420,
total_tasks=3,
total_videos=8,
total_chunks=1,
)
@pytest.fixture(scope="session")
def stats(stats_factory) -> list:
return stats_factory()
@pytest.fixture(scope="session")
def tasks() -> list:
return [
{"task_index": 0, "task": "Pick up the block."},
{"task_index": 1, "task": "Open the box."},
{"task_index": 2, "task": "Make paperclips."},
]
@pytest.fixture(scope="session")
def episodes() -> list:
return [
{"episode_index": 0, "tasks": ["Pick up the block."], "length": 100},
{"episode_index": 1, "tasks": ["Open the box."], "length": 80},
{"episode_index": 2, "tasks": ["Pick up the block."], "length": 90},
{"episode_index": 3, "tasks": ["Make paperclips."], "length": 150},
]
@pytest.fixture(scope="session")
def episode_data_index(episodes) -> dict:
return get_episode_data_index(episodes)
@pytest.fixture(scope="session")
def hf_dataset(hf_dataset_factory) -> datasets.Dataset:
return hf_dataset_factory()
@pytest.fixture(scope="session")
def hf_dataset_image(hf_dataset_factory) -> datasets.Dataset:
image_keys = DUMMY_CAMERA_KEYS
return hf_dataset_factory(image_keys=image_keys)

View File

@@ -11,16 +11,19 @@ import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.defaults import (
DEFAULT_FPS,
DUMMY_CAMERA_KEYS,
DUMMY_KEYS,
DUMMY_CAMERA_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE,
DUMMY_VIDEO_INFO,
)
@@ -73,16 +76,33 @@ def img_factory(img_array_factory):
@pytest.fixture(scope="session")
def info_factory():
def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
return {
**motor_features,
**camera_ft,
**DEFAULT_FEATURES,
}
return _create_features
@pytest.fixture(scope="session")
def info_factory(features_factory):
def _create_info(
codebase_version: str = CODEBASE_VERSION,
fps: int = DEFAULT_FPS,
robot_type: str = DUMMY_ROBOT_TYPE,
keys: list[str] = DUMMY_KEYS,
image_keys: list[str] | None = None,
video_keys: list[str] = DUMMY_CAMERA_KEYS,
shapes: dict | None = None,
names: dict | None = None,
total_episodes: int = 0,
total_frames: int = 0,
total_tasks: int = 0,
@@ -90,30 +110,14 @@ def info_factory():
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH,
videos_path: str = DEFAULT_VIDEO_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
if not image_keys:
image_keys = []
if not shapes:
shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
if not names:
names = {key: [f"motor_{i}" for i in range(shapes[key])] for key in keys}
video_info = {"videos_path": videos_path}
for key in video_keys:
video_info[key] = {
"video.fps": fps,
"video.width": shapes[key]["width"],
"video.height": shapes[key]["height"],
"video.channels": shapes[key]["channels"],
"video.codec": "av1",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"has_audio": False,
}
features = features_factory(motor_features, camera_features, use_videos)
return {
"codebase_version": codebase_version,
"data_path": data_path,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": total_frames,
@@ -123,12 +127,9 @@ def info_factory():
"chunks_size": chunks_size,
"fps": fps,
"splits": {},
"keys": keys,
"video_keys": video_keys,
"image_keys": image_keys,
"shapes": shapes,
"names": names,
"videos": video_info if len(video_keys) > 0 else None,
"data_path": data_path,
"video_path": video_path if use_videos else None,
"features": features,
}
return _create_info
@@ -137,32 +138,26 @@ def info_factory():
@pytest.fixture(scope="session")
def stats_factory():
def _create_stats(
keys: list[str] = DUMMY_KEYS,
image_keys: list[str] | None = None,
video_keys: list[str] = DUMMY_CAMERA_KEYS,
shapes: dict | None = None,
features: dict[str] | None = None,
) -> dict:
if not image_keys:
image_keys = []
if not shapes:
shapes = make_dummy_shapes(keys=keys, camera_keys=[*image_keys, *video_keys])
stats = {}
for key in keys:
shape = shapes[key]
stats[key] = {
"max": np.full(shape, 1, dtype=np.float32).tolist(),
"mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
"min": np.full(shape, 0, dtype=np.float32).tolist(),
"std": np.full(shape, 0.25, dtype=np.float32).tolist(),
}
for key in [*image_keys, *video_keys]:
shape = (3, 1, 1)
stats[key] = {
"max": np.full(shape, 1, dtype=np.float32).tolist(),
"mean": np.full(shape, 0.5, dtype=np.float32).tolist(),
"min": np.full(shape, 0, dtype=np.float32).tolist(),
"std": np.full(shape, 0.25, dtype=np.float32).tolist(),
}
for key, ft in features.items():
shape = ft["shape"]
dtype = ft["dtype"]
if dtype in ["image", "video"]:
stats[key] = {
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
}
else:
stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(),
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
"min": np.full(shape, 0, dtype=dtype).tolist(),
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
}
return stats
return _create_stats
@@ -185,7 +180,7 @@ def episodes_factory(tasks_factory):
def _create_episodes(
total_episodes: int = 3,
total_frames: int = 400,
task_dicts: dict | None = None,
tasks: dict | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
@@ -193,18 +188,18 @@ def episodes_factory(tasks_factory):
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not task_dicts:
if not tasks:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
task_dicts = tasks_factory(total_tasks)
tasks = tasks_factory(total_tasks)
if total_episodes < len(task_dicts) and not multi_task:
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in task_dicts]
tasks_list = [task_dict["task"] for task_dict in tasks]
num_tasks_available = len(tasks_list)
episodes_list = []
@@ -231,81 +226,56 @@ def episodes_factory(tasks_factory):
@pytest.fixture(scope="session")
def hf_dataset_factory(img_array_factory, episodes, tasks):
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
episode_dicts: list[dict] = episodes,
task_dicts: list[dict] = tasks,
keys: list[str] = DUMMY_KEYS,
image_keys: list[str] | None = None,
shapes: dict | None = None,
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not image_keys:
image_keys = []
if not shapes:
shapes = make_dummy_shapes(keys=keys, camera_keys=image_keys)
key_features = {
key: datasets.Sequence(length=shapes[key], feature=datasets.Value(dtype="float32"))
for key in keys
}
image_features = {key: datasets.Image() for key in image_keys} if image_keys else {}
common_features = {
"episode_index": datasets.Value(dtype="int64"),
"frame_index": datasets.Value(dtype="int64"),
"timestamp": datasets.Value(dtype="float32"),
"next.done": datasets.Value(dtype="bool"),
"index": datasets.Value(dtype="int64"),
"task_index": datasets.Value(dtype="int64"),
}
features = datasets.Features(
{
**key_features,
**image_features,
**common_features,
}
)
if not tasks:
tasks = tasks_factory()
if not episodes:
episodes = episodes_factory()
if not features:
features = features_factory()
episode_index_col = np.array([], dtype=np.int64)
frame_index_col = np.array([], dtype=np.int64)
timestamp_col = np.array([], dtype=np.float32)
next_done_col = np.array([], dtype=bool)
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episode_dicts:
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
next_done_ep = np.full(ep_dict["length"], False, dtype=bool)
next_done_ep[-1] = True
next_done_col = np.concatenate((next_done_col, next_done_ep))
ep_task_index = get_task_index(task_dicts, ep_dict["tasks"][0])
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
index_col = np.arange(len(episode_index_col))
key_cols = {key: np.random.random((len(index_col), shapes[key])).astype(np.float32) for key in keys}
image_cols = {}
if image_keys:
for key in image_keys:
image_cols[key] = [
img_array_factory(width=shapes[key]["width"], height=shapes[key]["height"])
robot_cols = {}
for key, ft in features.items():
if ft["dtype"] == "image":
robot_cols[key] = [
img_array_factory(width=ft["shapes"][0], height=ft["shapes"][1])
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
{
**key_cols,
**image_cols,
"episode_index": episode_index_col,
"frame_index": frame_index_col,
**robot_cols,
"timestamp": timestamp_col,
"next.done": next_done_col,
"frame_index": frame_index_col,
"episode_index": episode_index_col,
"index": index_col,
"task_index": task_index,
},
features=features,
features=hf_features,
)
dataset.set_transform(hf_transform_to_torch)
return dataset
@@ -315,26 +285,37 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
info,
stats,
tasks,
episodes,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset_metadata(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info_dict: dict = info,
stats_dict: dict = stats,
task_dicts: list[dict] = tasks,
episode_dicts: list[dict] = episodes,
**kwargs,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
local_files_only: bool = False,
) -> LeRobotDatasetMetadata:
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
info=info,
stats=stats,
tasks=tasks,
episodes=episodes,
)
with (
patch(
@@ -347,48 +328,68 @@ def lerobot_dataset_metadata_factory(
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(
repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
)
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info,
stats,
tasks,
episodes,
hf_dataset,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info_dict: dict = info,
stats_dict: dict = stats,
task_dicts: list[dict] = tasks,
episode_dicts: list[dict] = episodes,
hf_ds: datasets.Dataset = hf_dataset,
total_episodes: int = 3,
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
episode_dicts = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
hf_ds=hf_ds,
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info_dict=info_dict,
stats_dict=stats_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
**kwargs,
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
@@ -402,44 +403,3 @@ def lerobot_dataset_factory(
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
return _create_lerobot_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_from_episodes_factory(
info_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
lerobot_dataset_factory,
):
def _create_lerobot_dataset_total_episodes(
root: Path,
total_episodes: int = 3,
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
repo_id: str = DUMMY_REPO_ID,
**kwargs,
):
info_dict = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
task_dicts = tasks_factory(total_tasks)
episode_dicts = episodes_factory(
total_episodes=total_episodes,
total_frames=total_frames,
task_dicts=task_dicts,
multi_task=multi_task,
)
hf_dataset = hf_dataset_factory(episode_dicts=episode_dicts, task_dicts=task_dicts)
return lerobot_dataset_factory(
root=root,
repo_id=repo_id,
info_dict=info_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
hf_ds=hf_dataset,
**kwargs,
)
return _create_lerobot_dataset_total_episodes

View File

@@ -3,6 +3,27 @@ from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_KEYS = ["state", "action"]
DUMMY_CAMERA_KEYS = ["laptop", "phone"]
DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
"phone": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {
"video.fps": DEFAULT_FPS,
"video.codec": "av1",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"has_audio": False,
}

View File

@@ -11,64 +11,77 @@ from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH,
@pytest.fixture(scope="session")
def info_path(info):
def _create_info_json_file(dir: Path, info_dict: dict = info) -> Path:
def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
if not info:
info = info_factory()
fpath = dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info_dict, f, indent=4, ensure_ascii=False)
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
return _create_info_json_file
@pytest.fixture(scope="session")
def stats_path(stats):
def _create_stats_json_file(dir: Path, stats_dict: dict = stats) -> Path:
def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
if not stats:
stats = stats_factory()
fpath = dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats_dict, f, indent=4, ensure_ascii=False)
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
return _create_stats_json_file
@pytest.fixture(scope="session")
def tasks_path(tasks):
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
if not tasks:
tasks = tasks_factory()
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(task_dicts)
writer.write_all(tasks)
return fpath
return _create_tasks_jsonl_file
@pytest.fixture(scope="session")
def episode_path(episodes):
def _create_episodes_jsonl_file(dir: Path, episode_dicts: list = episodes) -> Path:
def episode_path(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
if not episodes:
episodes = episodes_factory()
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episode_dicts)
writer.write_all(episodes)
return fpath
return _create_episodes_jsonl_file
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset, info):
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, hf_ds: datasets.Dataset = hf_dataset, ep_idx: int = 0
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_ds.data.table
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return fpath
@@ -77,8 +90,15 @@ def single_episode_parquet_path(hf_dataset, info):
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset, info):
def _create_multi_episode_parquet(dir: Path, hf_ds: datasets.Dataset = hf_dataset) -> Path:
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
total_episodes = info["total_episodes"]
@@ -86,7 +106,7 @@ def multi_episode_parquet_path(hf_dataset, info):
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_ds.data.table
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return dir / "data"

46
tests/fixtures/hub.py vendored
View File

@@ -1,5 +1,6 @@
from pathlib import Path
import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
@@ -9,16 +10,16 @@ from tests.fixtures.defaults import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info,
info_factory,
info_path,
stats,
stats_factory,
stats_path,
tasks,
tasks_factory,
tasks_path,
episodes,
episodes_factory,
episode_path,
single_episode_parquet_path,
hf_dataset,
hf_dataset_factory,
):
"""
This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -26,8 +27,25 @@ def mock_snapshot_download_factory(
"""
def _mock_snapshot_download_func(
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
@@ -53,10 +71,10 @@ def mock_snapshot_download_factory(
all_files.extend(meta_files)
data_files = []
for episode_dict in episode_dicts:
for episode_dict in episodes:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info_dict["chunks_size"]
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
@@ -69,15 +87,15 @@ def mock_snapshot_download_factory(
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, hf_ds, ep_idx=episode_index)
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info_dict)
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats_dict)
_ = stats_path(local_dir, stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, task_dicts)
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episode_dicts)
_ = episode_path(local_dir, episodes)
else:
pass
return str(local_dir)