forked from tangger/lerobot
Progress on aggregate_datasets
This commit is contained in:
9
tests/fixtures/dataset_factories.py
vendored
9
tests/fixtures/dataset_factories.py
vendored
@@ -212,6 +212,7 @@ def tasks_factory():
|
||||
def episodes_factory(tasks_factory, stats_factory):
|
||||
def _create_episodes(
|
||||
features: dict[str],
|
||||
fps: int = DEFAULT_FPS,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
video_keys: list[str] | None = None,
|
||||
@@ -252,6 +253,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"] = []
|
||||
d[f"videos/{video_key}/file_index"] = []
|
||||
d[f"videos/{video_key}/from_timestamp"] = []
|
||||
d[f"videos/{video_key}/to_timestamp"] = []
|
||||
|
||||
for stats_key in flatten_dict({"stats": stats_factory(features)}):
|
||||
d[stats_key] = []
|
||||
@@ -281,6 +284,8 @@ def episodes_factory(tasks_factory, stats_factory):
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"].append(0)
|
||||
d[f"videos/{video_key}/file_index"].append(0)
|
||||
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
|
||||
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
|
||||
|
||||
# Add stats columns like "stats/action/max"
|
||||
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
|
||||
@@ -306,7 +311,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
if features is None:
|
||||
features = features_factory()
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(features)
|
||||
episodes = episodes_factory(features, fps)
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
@@ -379,6 +384,7 @@ def lerobot_dataset_metadata_factory(
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
@@ -441,6 +447,7 @@ def lerobot_dataset_factory(
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes_metadata = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
|
||||
1
tests/fixtures/files.py
vendored
1
tests/fixtures/files.py
vendored
@@ -73,6 +73,7 @@ def create_tasks(tasks_factory):
|
||||
def create_episodes(episodes_factory):
|
||||
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
|
||||
if episodes is None:
|
||||
# TODO(rcadene): add features, fps as arguments
|
||||
episodes = episodes_factory()
|
||||
write_episodes(episodes, dir)
|
||||
|
||||
|
||||
2
tests/fixtures/hub.py
vendored
2
tests/fixtures/hub.py
vendored
@@ -62,6 +62,7 @@ def mock_snapshot_download_factory(
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
@@ -121,6 +122,7 @@ def mock_snapshot_download_factory(
|
||||
create_tasks(local_dir, tasks)
|
||||
if has_episodes:
|
||||
create_episodes(local_dir, episodes)
|
||||
# TODO(rcadene): create_videos?
|
||||
if has_data:
|
||||
create_hf_dataset(local_dir, hf_dataset)
|
||||
|
||||
|
||||
@@ -1,19 +1,29 @@
|
||||
from lerobot.common.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
dataset_0 = lerobot_dataset_factory(
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_0",
|
||||
repo_id=DUMMY_REPO_ID + "_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_0",
|
||||
total_episodes=10,
|
||||
total_frames=400,
|
||||
)
|
||||
dataset_1 = lerobot_dataset_factory(
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_1",
|
||||
repo_id=DUMMY_REPO_ID + "_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_1",
|
||||
total_episodes=10,
|
||||
total_frames=400,
|
||||
)
|
||||
|
||||
dataset_2 = aggregate_datasets([dataset_0, dataset_1])
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
|
||||
aggr_root=tmp_path / "test_aggr"
|
||||
)
|
||||
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
|
||||
for item in aggr_ds:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user