diff --git a/examples/port_datasets/droid_rlds/port_droid.py b/examples/port_datasets/droid_rlds/port_droid.py
index f5c903fb2..3bcb0ba36 100644
--- a/examples/port_datasets/droid_rlds/port_droid.py
+++ b/examples/port_datasets/droid_rlds/port_droid.py
@@ -22,7 +22,7 @@ from pathlib import Path
 import numpy as np
 import tensorflow_datasets as tfds
 
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
 
 DROID_SHARDS = 2048
@@ -370,6 +370,25 @@ def port_droid(
     )
 
 
+def validate_dataset(repo_id):
+    """Sanity check that ensures metadata can be loaded and all files are present."""
+    meta = LeRobotDatasetMetadata(repo_id)
+
+    if meta.total_episodes == 0:
+        raise ValueError("Number of episodes is 0.")
+
+    for ep_idx in range(meta.total_episodes):
+        data_path = meta.root / meta.get_data_file_path(ep_idx)
+
+        if not data_path.exists():
+            raise ValueError(f"Parquet file is missing in: {data_path}")
+
+        for vid_key in meta.video_keys:
+            vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
+            if not vid_path.exists():
+                raise ValueError(f"Video file is missing in: {vid_path}")
+
+
 def main():
     parser = argparse.ArgumentParser()
 
diff --git a/examples/port_datasets/droid_rlds/slurm_port_shards.py b/examples/port_datasets/droid_rlds/slurm_port_shards.py
index 08e36bc39..7a1e8dd2b 100644
--- a/examples/port_datasets/droid_rlds/slurm_port_shards.py
+++ b/examples/port_datasets/droid_rlds/slurm_port_shards.py
@@ -9,25 +9,6 @@ from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
 
 
-def validate_shard(repo_id):
-    """Sanity check that ensure meta data can be loaded and all files are present."""
-    meta = LeRobotDatasetMetadata(repo_id)
-
-    if meta.total_episodes == 0:
-        raise ValueError("Number of episodes is 0.")
-
-    for ep_idx in range(meta.total_episodes):
-        data_path = meta.root / meta.get_data_file_path(ep_idx)
-
-        if not data_path.exists():
-            raise ValueError(f"Parquet file is missing in: {data_path}")
-
-        for vid_key in meta.video_keys:
-            vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
-            if not vid_path.exists():
-                raise ValueError(f"Video file is missing in: {vid_path}")
-
-
 class PortDroidShards(PipelineStep):
     def __init__(
         self,
@@ -41,7 +22,7 @@ class PortDroidShards(PipelineStep):
     def run(self, data=None, rank: int = 0, world_size: int = 1):
         from datasets.utils.tqdm import disable_progress_bars
 
-        from examples.port_datasets.droid_rlds.port_droid import port_droid
+        from examples.port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
         from lerobot.common.utils.utils import init_logging
 
         init_logging()
@@ -57,7 +38,7 @@ class PortDroidShards(PipelineStep):
             shard_index=rank,
         )
 
-        validate_shard(shard_repo_id)
+        validate_dataset(shard_repo_id)
 
 
 def make_port_executor(
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 6edf46523..58dd94007 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -126,9 +126,8 @@ def load_nested_dataset(pq_dir: Path) -> Dataset:
         raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
 
     # TODO(rcadene): set num_proc to accelerate conversion to pyarrow
-    return concatenate_datasets(
-        [Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))]
-    )
+    datasets = [Dataset.from_parquet(str(path)) for path in paths]
+    return concatenate_datasets(datasets)
 
 
 def get_parquet_num_frames(parquet_path):
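A minimal usage sketch of the relocated helper, assuming a shard has already been ported and is reachable by `LeRobotDatasetMetadata`; the repo id below is a hypothetical placeholder, not a name from this patch:

```python
# Sketch only: run the relocated sanity check on an already-ported shard.
from examples.port_datasets.droid_rlds.port_droid import validate_dataset

try:
    validate_dataset("my-org/droid_shard_0000")  # hypothetical shard repo id
except ValueError as err:
    # Raised when the shard has zero episodes or a parquet/video file is missing.
    print(f"Shard failed validation: {err}")
```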