Progress on aggregate_datasets

This commit is contained in:
Remi Cadene
2025-04-19 19:11:53 +05:30
parent 54b5c805bf
commit b0cca75e5e
7 changed files with 149 additions and 87 deletions

View File

@@ -1,19 +1,29 @@
from lerobot.common.datasets.aggregate import aggregate_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate two dummy datasets, then reload and iterate the result.

    Two source datasets are built under separate sub-directories of
    ``tmp_path``, merged with ``aggregate_datasets``, and the aggregated
    dataset is re-opened from disk through ``LeRobotDataset``.

    Args:
        tmp_path: Base directory for the source and aggregated dataset roots
            (presumably pytest's built-in ``tmp_path`` fixture).
        lerobot_dataset_factory: Callable that builds a dataset from ``root``,
            ``repo_id``, ``total_episodes`` and ``total_frames``
            (presumably a project fixture; confirm against the test fixtures).
    """
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "test_0",
        repo_id=f"{DUMMY_REPO_ID}_0",
        total_episodes=10,
        total_frames=400,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "test_1",
        repo_id=f"{DUMMY_REPO_ID}_1",
        total_episodes=10,
        total_frames=400,
    )

    # Name the aggregation target once so the write and the reload below
    # cannot drift apart.
    aggr_repo_id = f"{DUMMY_REPO_ID}_aggr"
    aggr_root = tmp_path / "test_aggr"

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Smoke check only: walking the whole dataset surfaces any error raised
    # while loading an item. No content assertions are made yet.
    for _ in aggr_ds:
        pass