Commit before episodes episodes_stats merging
tests/fixtures/dataset_factories.py (vendored): 97 changed lines
@@ -9,13 +9,16 @@ import numpy as np
 import PIL.Image
 import pytest
 import torch
+from datasets import Dataset

 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
     DEFAULT_FEATURES,
-    DEFAULT_PARQUET_PATH,
+    DEFAULT_DATA_PATH,
+    DEFAULT_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
+    flatten_dict,
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
@@ -33,10 +36,9 @@ class LeRobotDatasetFactory(Protocol):
     def __call__(self, *args, **kwargs) -> LeRobotDataset: ...


-def get_task_index(task_dicts: dict, task: str) -> int:
-    tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
-    task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
-    return task_to_task_index[task]
+def get_task_index(tasks: Dataset, task: str) -> int:
+    task_idx = tasks["task"].index(task)
+    return task_idx


 @pytest.fixture(scope="session")
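Note: a minimal sketch (not part of the diff, assuming the Hugging Face datasets API) of how the new Dataset-backed task lookup behaves:

    from datasets import Dataset

    tasks = Dataset.from_dict({"task_index": [0, 1, 2], "task": ["pick", "place", "push"]})
    # Column access on a Dataset returns a plain Python list, so .index() gives the
    # row position, which here doubles as the task_index.
    assert tasks["task"].index("place") == 1  # same lookup get_task_index performs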
@@ -104,9 +106,9 @@ def info_factory(features_factory):
         total_frames: int = 0,
         total_tasks: int = 0,
         total_videos: int = 0,
-        total_chunks: int = 0,
         chunks_size: int = DEFAULT_CHUNK_SIZE,
-        data_path: str = DEFAULT_PARQUET_PATH,
+        files_size_in_mb: float = DEFAULT_FILE_SIZE_IN_MB,
+        data_path: str = DEFAULT_DATA_PATH,
         video_path: str = DEFAULT_VIDEO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
@@ -120,8 +122,8 @@ def info_factory(features_factory):
             "total_frames": total_frames,
             "total_tasks": total_tasks,
             "total_videos": total_videos,
-            "total_chunks": total_chunks,
             "chunks_size": chunks_size,
+            "files_size_in_mb": files_size_in_mb,
             "fps": fps,
             "splits": {},
             "data_path": data_path,
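Note: an illustrative sketch (not part of the diff) of the info dict the factory now assembles, with files_size_in_mb taking the place of total_chunks; the values and the data_path template are made up:

    info = {
        "total_episodes": 3,
        "total_frames": 400,
        "chunks_size": 1000,        # DEFAULT_CHUNK_SIZE stand-in
        "files_size_in_mb": 100.0,  # DEFAULT_FILE_SIZE_IN_MB stand-in, replaces total_chunks
        "fps": 30,
        "splits": {},
        "data_path": "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet",  # hypothetical template
    }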
@@ -168,25 +170,25 @@ def episodes_stats_factory(stats_factory):
         features: dict[str],
         total_episodes: int = 3,
     ) -> dict:
-        episodes_stats = {}
-        for episode_index in range(total_episodes):
-            episodes_stats[episode_index] = {
-                "episode_index": episode_index,
-                "stats": stats_factory(features),
-            }
-        return episodes_stats
+        def _generator(total_episodes):
+            for ep_idx in range(total_episodes):
+                flat_ep_stats = flatten_dict(stats_factory(features))
+                flat_ep_stats["episode_index"] = ep_idx
+                yield flat_ep_stats
+
+        # Simpler to rely on generator instead of from_dict
+        return Dataset.from_generator(lambda: _generator(total_episodes))

     return _create_episodes_stats


 @pytest.fixture(scope="session")
 def tasks_factory():
-    def _create_tasks(total_tasks: int = 3) -> int:
-        tasks = {}
-        for task_index in range(total_tasks):
-            task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
-            tasks[task_index] = task_dict
-        return tasks
+    def _create_tasks(total_tasks: int = 3) -> Dataset:
+        ids = list(range(total_tasks))
+        tasks = [f"Perform action {i}." for i in ids]
+        return Dataset.from_dict({"task_index": ids, "task": tasks})

     return _create_tasks

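Note: a minimal sketch (not part of the diff) of the Dataset.from_generator pattern the rewritten episodes_stats_factory relies on; the yielded keys stand in for flatten_dict(stats_factory(features)) and are illustrative:

    from datasets import Dataset

    def _gen(total_episodes):
        for ep_idx in range(total_episodes):
            # one flat row of per-episode stats
            yield {"episode_index": ep_idx, "observation.state/mean": [0.0], "observation.state/std": [1.0]}

    episodes_stats = Dataset.from_generator(lambda: _gen(3))
    # columns: episode_index plus the flattened stats keys; one row per episode
    assert episodes_stats.num_rows == 3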
@@ -196,6 +198,7 @@ def episodes_factory(tasks_factory):
     def _create_episodes(
         total_episodes: int = 3,
         total_frames: int = 400,
+        video_keys: list[str] | None = None,
         tasks: dict | None = None,
         multi_task: bool = False,
     ):
@@ -215,26 +218,41 @@ def episodes_factory(tasks_factory):
         # Generate random lengths that sum up to total_length
         lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

-        tasks_list = [task_dict["task"] for task_dict in tasks.values()]
-        num_tasks_available = len(tasks_list)
+        num_tasks_available = len(tasks["task"])

-        episodes = {}
-        remaining_tasks = tasks_list.copy()
+        d = {
+            "episode_index": [],
+            "data/chunk_index": [],
+            "data/file_index": [],
+            "tasks": [],
+            "length": [],
+        }
+        if video_keys is not None:
+            for video_key in video_keys:
+                d[f"{video_key}/chunk_index"] = []
+                d[f"{video_key}/file_index"] = []
+
+        remaining_tasks = tasks["task"].copy()
         for ep_idx in range(total_episodes):
             num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
-            tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
+            tasks_to_sample = remaining_tasks if remaining_tasks else tasks["task"]
             episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
             if remaining_tasks:
                 for task in episode_tasks:
                     remaining_tasks.remove(task)

-            episodes[ep_idx] = {
-                "episode_index": ep_idx,
-                "tasks": episode_tasks,
-                "length": lengths[ep_idx],
-            }
+            d["episode_index"].append(ep_idx)
+            # TODO(rcadene): remove heuristic of only one file
+            d["data/chunk_index"].append(0)
+            d["data/file_index"].append(0)
+            d["tasks"].append(episode_tasks)
+            d["length"].append(lengths[ep_idx])
+            if video_keys is not None:
+                for video_key in video_keys:
+                    d[f"{video_key}/chunk_index"].append(0)
+                    d[f"{video_key}/file_index"].append(0)

-        return episodes
+        return Dataset.from_dict(d)

     return _create_episodes

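Note: an illustrative sketch (not part of the diff) of the episodes table the rewritten factory builds with Dataset.from_dict; column names follow the diff, while the video key and the values are made up:

    from datasets import Dataset

    d = {
        "episode_index": [0, 1],
        "data/chunk_index": [0, 0],   # single chunk/file heuristic noted in the TODO
        "data/file_index": [0, 0],
        "tasks": [["Perform action 0."], ["Perform action 1."]],
        "length": [210, 190],
        "observation.images.cam/chunk_index": [0, 0],   # hypothetical video key
        "observation.images.cam/file_index": [0, 0],
    }
    episodes = Dataset.from_dict(d)
    assert episodes[0]["data/chunk_index"] == 0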
@@ -258,7 +276,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
         frame_index_col = np.array([], dtype=np.int64)
         episode_index_col = np.array([], dtype=np.int64)
         task_index = np.array([], dtype=np.int64)
-        for ep_dict in episodes.values():
+        for ep_dict in episodes:
             timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
             frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
             episode_index_col = np.concatenate(
@@ -291,7 +309,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
             },
             features=hf_features,
         )
-        dataset.set_transform(hf_transform_to_torch)
+        dataset.set_format("torch")
         return dataset

     return _create_hf_dataset
@@ -326,8 +344,9 @@ def lerobot_dataset_metadata_factory(
         if not tasks:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if not episodes:
+            video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episodes = episodes_factory(
-                total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
+                total_episodes=info["total_episodes"], total_frames=info["total_frames"], video_keys=video_keys, tasks=tasks
             )

         mock_snapshot_download = mock_snapshot_download_factory(
@@ -371,9 +390,9 @@ def lerobot_dataset_factory(
         multi_task: bool = False,
         info: dict | None = None,
         stats: dict | None = None,
-        episodes_stats: list[dict] | None = None,
-        tasks: list[dict] | None = None,
-        episode_dicts: list[dict] | None = None,
+        episodes_stats: datasets.Dataset | None = None,
+        tasks: datasets.Dataset | None = None,
+        episode_dicts: datasets.Dataset | None = None,
         hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
@@ -388,9 +407,11 @@ def lerobot_dataset_factory(
         if not tasks:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if not episode_dicts:
+            video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episode_dicts = episodes_factory(
                 total_episodes=info["total_episodes"],
                 total_frames=info["total_frames"],
+                video_keys=video_keys,
                 tasks=tasks,
                 multi_task=multi_task,
             )

tests/fixtures/files.py (vendored): 84 changed lines
@@ -7,83 +7,75 @@ import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import pytest

+from datasets import Dataset
+
 from lerobot.common.datasets.utils import (
-    EPISODES_PATH,
-    EPISODES_STATS_PATH,
-    INFO_PATH,
-    STATS_PATH,
-    TASKS_PATH,
+    write_episodes,
+    write_episodes_stats,
+    write_hf_dataset,
+    write_info,
+    write_stats,
+    write_tasks,
 )


 @pytest.fixture(scope="session")
-def info_path(info_factory):
-    def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
+def create_info(info_factory):
+    def _create_info(dir: Path, info: dict | None = None):
         if not info:
             info = info_factory()
-        fpath = dir / INFO_PATH
-        fpath.parent.mkdir(parents=True, exist_ok=True)
-        with open(fpath, "w") as f:
-            json.dump(info, f, indent=4, ensure_ascii=False)
-        return fpath
+        write_info(info, dir)

-    return _create_info_json_file
+    return _create_info


 @pytest.fixture(scope="session")
-def stats_path(stats_factory):
-    def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
+def create_stats(stats_factory):
+    def _create_stats(dir: Path, stats: dict | None = None):
         if not stats:
             stats = stats_factory()
-        fpath = dir / STATS_PATH
-        fpath.parent.mkdir(parents=True, exist_ok=True)
-        with open(fpath, "w") as f:
-            json.dump(stats, f, indent=4, ensure_ascii=False)
-        return fpath
+        write_stats(stats, dir)

-    return _create_stats_json_file
+    return _create_stats


 @pytest.fixture(scope="session")
-def episodes_stats_path(episodes_stats_factory):
-    def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
+def create_episodes_stats(episodes_stats_factory):
+    def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
         if not episodes_stats:
             episodes_stats = episodes_stats_factory()
-        fpath = dir / EPISODES_STATS_PATH
-        fpath.parent.mkdir(parents=True, exist_ok=True)
-        with jsonlines.open(fpath, "w") as writer:
-            writer.write_all(episodes_stats.values())
-        return fpath
+        write_episodes_stats(episodes_stats, dir)

-    return _create_episodes_stats_jsonl_file
+    return _create_episodes_stats


 @pytest.fixture(scope="session")
-def tasks_path(tasks_factory):
-    def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
+def create_tasks(tasks_factory):
+    def _create_tasks(dir: Path, tasks: Dataset | None = None):
         if not tasks:
             tasks = tasks_factory()
-        fpath = dir / TASKS_PATH
-        fpath.parent.mkdir(parents=True, exist_ok=True)
-        with jsonlines.open(fpath, "w") as writer:
-            writer.write_all(tasks.values())
-        return fpath
+        write_tasks(tasks, dir)

-    return _create_tasks_jsonl_file
+    return _create_tasks


 @pytest.fixture(scope="session")
-def episode_path(episodes_factory):
-    def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
+def create_episodes(episodes_factory):
+    def _create_episodes(dir: Path, episodes: Dataset | None = None):
         if not episodes:
             episodes = episodes_factory()
-        fpath = dir / EPISODES_PATH
-        fpath.parent.mkdir(parents=True, exist_ok=True)
-        with jsonlines.open(fpath, "w") as writer:
-            writer.write_all(episodes.values())
-        return fpath
+        write_episodes(episodes, dir)

-    return _create_episodes_jsonl_file
+    return _create_episodes
+
+
+@pytest.fixture(scope="session")
+def create_hf_dataset(hf_dataset_factory):
+    def _create_hf_dataset(dir: Path, hf_dataset: Dataset | None = None):
+        if not hf_dataset:
+            hf_dataset = hf_dataset_factory()
+        write_hf_dataset(hf_dataset, dir)
+
+    return _create_hf_dataset


 @pytest.fixture(scope="session")
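Note: a hypothetical usage sketch (not part of the diff) of the renamed fixtures, which now write files through the write_* helpers instead of returning a path; tmp_path is pytest's built-in temporary directory fixture:

    def test_tasks_written_to_disk(tmp_path, create_tasks, tasks_factory):
        tasks = tasks_factory(total_tasks=2)  # datasets.Dataset with "task_index" and "task" columns
        create_tasks(tmp_path, tasks)         # delegates to write_tasks(tasks, tmp_path)
        assert any(tmp_path.rglob("*"))       # something was materialized under tmp_path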
@@ -91,6 +83,7 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
     def _create_single_episode_parquet(
         dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
     ) -> Path:
+        raise NotImplementedError()
         if not info:
             info = info_factory()
         if hf_dataset is None:
@@ -114,6 +107,7 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
     def _create_multi_episode_parquet(
         dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
     ) -> Path:
+        raise NotImplementedError()
         if not info:
             info = info_factory()
         if hf_dataset is None:
tests/fixtures/hub.py (vendored): 104 changed lines
@@ -5,11 +5,12 @@ import pytest
 from huggingface_hub.utils import filter_repo_objects

 from lerobot.common.datasets.utils import (
-    EPISODES_PATH,
-    EPISODES_STATS_PATH,
+    DEFAULT_DATA_PATH,
+    DEFAULT_EPISODES_PATH,
+    DEFAULT_EPISODES_STATS_PATH,
+    DEFAULT_TASKS_PATH,
     INFO_PATH,
-    STATS_PATH,
-    TASKS_PATH,
+    LEGACY_STATS_PATH,
 )
 from tests.fixtures.constants import LEROBOT_TEST_DIR

@@ -17,17 +18,17 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
 @pytest.fixture(scope="session")
 def mock_snapshot_download_factory(
     info_factory,
-    info_path,
+    create_info,
     stats_factory,
-    stats_path,
+    create_stats,
     episodes_stats_factory,
-    episodes_stats_path,
+    create_episodes_stats,
     tasks_factory,
-    tasks_path,
+    create_tasks,
     episodes_factory,
-    episode_path,
-    single_episode_parquet_path,
+    create_episodes,
     hf_dataset_factory,
+    create_hf_dataset,
 ):
     """
     This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -37,9 +38,9 @@ def mock_snapshot_download_factory(
     def _mock_snapshot_download_func(
         info: dict | None = None,
         stats: dict | None = None,
-        episodes_stats: list[dict] | None = None,
-        tasks: list[dict] | None = None,
-        episodes: list[dict] | None = None,
+        episodes_stats: datasets.Dataset | None = None,
+        tasks: datasets.Dataset | None = None,
+        episodes: datasets.Dataset | None = None,
         hf_dataset: datasets.Dataset | None = None,
     ):
         if not info:
@@ -59,14 +60,6 @@ def mock_snapshot_download_factory(
         if not hf_dataset:
             hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])

-        def _extract_episode_index_from_path(fpath: str) -> int:
-            path = Path(fpath)
-            if path.suffix == ".parquet" and path.stem.startswith("episode_"):
-                episode_index = int(path.stem[len("episode_") :])  # 'episode_000000' -> 0
-                return episode_index
-            else:
-                return None
-
         def _mock_snapshot_download(
             repo_id: str,
             local_dir: str | Path | None = None,
@@ -79,40 +72,55 @@ def mock_snapshot_download_factory(
                 local_dir = LEROBOT_TEST_DIR

             # List all possible files
-            all_files = []
-            meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
-            all_files.extend(meta_files)
-
-            data_files = []
-            for episode_dict in episodes.values():
-                ep_idx = episode_dict["episode_index"]
-                ep_chunk = ep_idx // info["chunks_size"]
-                data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
-                data_files.append(data_path)
-            all_files.extend(data_files)
+            all_files = [
+                INFO_PATH,
+                LEGACY_STATS_PATH,
+                # TODO(rcadene)
+                DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
+                DEFAULT_EPISODES_STATS_PATH.format(chunk_index=0, file_index=0),
+                DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
+                DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
+            ]

             allowed_files = filter_repo_objects(
                 all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
             )

             # Create allowed files
+            has_info = False
+            has_tasks = False
+            has_episodes = False
+            has_episodes_stats = False
+            has_stats = False
+            has_data = False
             for rel_path in allowed_files:
-                if rel_path.startswith("data/"):
-                    episode_index = _extract_episode_index_from_path(rel_path)
-                    if episode_index is not None:
-                        _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
-                if rel_path == INFO_PATH:
-                    _ = info_path(local_dir, info)
-                elif rel_path == STATS_PATH:
-                    _ = stats_path(local_dir, stats)
-                elif rel_path == EPISODES_STATS_PATH:
-                    _ = episodes_stats_path(local_dir, episodes_stats)
-                elif rel_path == TASKS_PATH:
-                    _ = tasks_path(local_dir, tasks)
-                elif rel_path == EPISODES_PATH:
-                    _ = episode_path(local_dir, episodes)
+                if rel_path.startswith("meta/info.json"):
+                    has_info = True
+                elif rel_path.startswith("meta/stats"):
+                    has_stats = True
+                elif rel_path.startswith("meta/tasks"):
+                    has_tasks = True
+                elif rel_path.startswith("meta/episodes_stats"):
+                    has_episodes_stats = True
+                elif rel_path.startswith("meta/episodes"):
+                    has_episodes = True
+                elif rel_path.startswith("data/"):
+                    has_data = True
                 else:
-                    pass
+                    raise ValueError(f"{rel_path} not supported.")
+
+            if has_info:
+                create_info(local_dir, info)
+            if has_stats:
+                create_stats(local_dir, stats)
+            if has_tasks:
+                create_tasks(local_dir, tasks)
+            if has_episodes:
+                create_episodes(local_dir, episodes)
+            if has_episodes_stats:
+                create_episodes_stats(local_dir, episodes_stats)
+            if has_data:
+                create_hf_dataset(local_dir, hf_dataset)

             return str(local_dir)

         return _mock_snapshot_download
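Note: a minimal sketch (not part of the diff) of how filter_repo_objects from huggingface_hub narrows the candidate file list before the mock materializes anything; the file names are illustrative:

    from huggingface_hub.utils import filter_repo_objects

    all_files = [
        "meta/info.json",
        "meta/tasks/chunk-000/file-000.parquet",
        "data/chunk-000/file-000.parquet",
    ]
    # Keep only metadata files, as an allow_patterns argument reaching the mock would.
    allowed = list(filter_repo_objects(all_files, allow_patterns=["meta/*"]))
    assert allowed == ["meta/info.json", "meta/tasks/chunk-000/file-000.parquet"]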