Progress on aggregate_datasets

This commit is contained in:
Remi Cadene
2025-04-19 19:11:53 +05:30
parent 54b5c805bf
commit b0cca75e5e
7 changed files with 149 additions and 87 deletions

View File

@@ -212,6 +212,7 @@ def tasks_factory():
def episodes_factory(tasks_factory, stats_factory):
def _create_episodes(
features: dict[str],
fps: int = DEFAULT_FPS,
total_episodes: int = 3,
total_frames: int = 400,
video_keys: list[str] | None = None,
@@ -252,6 +253,8 @@ def episodes_factory(tasks_factory, stats_factory):
for video_key in video_keys:
d[f"videos/{video_key}/chunk_index"] = []
d[f"videos/{video_key}/file_index"] = []
d[f"videos/{video_key}/from_timestamp"] = []
d[f"videos/{video_key}/to_timestamp"] = []
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
@@ -281,6 +284,8 @@ def episodes_factory(tasks_factory, stats_factory):
for video_key in video_keys:
d[f"videos/{video_key}/chunk_index"].append(0)
d[f"videos/{video_key}/file_index"].append(0)
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
# Add stats columns like "stats/action/max"
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
@@ -306,7 +311,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
if features is None:
features = features_factory()
if episodes is None:
episodes = episodes_factory(features)
episodes = episodes_factory(features, fps)
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)
@@ -379,6 +384,7 @@ def lerobot_dataset_metadata_factory(
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
@@ -441,6 +447,7 @@ def lerobot_dataset_factory(
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,

View File

@@ -73,6 +73,7 @@ def create_tasks(tasks_factory):
def create_episodes(episodes_factory):
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
if episodes is None:
# TODO(rcadene): add features, fps as arguments
episodes = episodes_factory()
write_episodes(episodes, dir)

View File

@@ -62,6 +62,7 @@ def mock_snapshot_download_factory(
if episodes is None:
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
@@ -121,6 +122,7 @@ def mock_snapshot_download_factory(
create_tasks(local_dir, tasks)
if has_episodes:
create_episodes(local_dir, episodes)
# TODO(rcadene): create_videos?
if has_data:
create_hf_dataset(local_dir, hf_dataset)

View File

@@ -1,19 +1,29 @@
from lerobot.common.datasets.aggregate import aggregate_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
dataset_0 = lerobot_dataset_factory(
ds_0 = lerobot_dataset_factory(
root=tmp_path / "test_0",
repo_id=DUMMY_REPO_ID + "_0",
repo_id=f"{DUMMY_REPO_ID}_0",
total_episodes=10,
total_frames=400,
)
dataset_1 = lerobot_dataset_factory(
ds_1 = lerobot_dataset_factory(
root=tmp_path / "test_1",
repo_id=DUMMY_REPO_ID + "_1",
repo_id=f"{DUMMY_REPO_ID}_1",
total_episodes=10,
total_frames=400,
)
dataset_2 = aggregate_datasets([dataset_0, dataset_1])
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
aggr_root=tmp_path / "test_aggr"
)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
for item in aggr_ds:
pass