Uploaded droid 1.0.1
This commit is contained in:
@@ -22,7 +22,7 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||||
|
|
||||||
DROID_SHARDS = 2048
|
DROID_SHARDS = 2048
|
||||||
@@ -370,6 +370,25 @@ def port_droid(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_dataset(repo_id):
|
||||||
|
"""Sanity check that ensure meta data can be loaded and all files are present."""
|
||||||
|
meta = LeRobotDatasetMetadata(repo_id)
|
||||||
|
|
||||||
|
if meta.total_episodes == 0:
|
||||||
|
raise ValueError("Number of episodes is 0.")
|
||||||
|
|
||||||
|
for ep_idx in range(meta.total_episodes):
|
||||||
|
data_path = meta.root / meta.get_data_file_path(ep_idx)
|
||||||
|
|
||||||
|
if not data_path.exists():
|
||||||
|
raise ValueError(f"Parquet file is missing in: {data_path}")
|
||||||
|
|
||||||
|
for vid_key in meta.video_keys:
|
||||||
|
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
|
||||||
|
if not vid_path.exists():
|
||||||
|
raise ValueError(f"Video file is missing in: {vid_path}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
|||||||
@@ -9,25 +9,6 @@ from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
|||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
|
||||||
|
|
||||||
def validate_shard(repo_id):
|
|
||||||
"""Sanity check that ensure meta data can be loaded and all files are present."""
|
|
||||||
meta = LeRobotDatasetMetadata(repo_id)
|
|
||||||
|
|
||||||
if meta.total_episodes == 0:
|
|
||||||
raise ValueError("Number of episodes is 0.")
|
|
||||||
|
|
||||||
for ep_idx in range(meta.total_episodes):
|
|
||||||
data_path = meta.root / meta.get_data_file_path(ep_idx)
|
|
||||||
|
|
||||||
if not data_path.exists():
|
|
||||||
raise ValueError(f"Parquet file is missing in: {data_path}")
|
|
||||||
|
|
||||||
for vid_key in meta.video_keys:
|
|
||||||
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
|
|
||||||
if not vid_path.exists():
|
|
||||||
raise ValueError(f"Video file is missing in: {vid_path}")
|
|
||||||
|
|
||||||
|
|
||||||
class PortDroidShards(PipelineStep):
|
class PortDroidShards(PipelineStep):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -41,7 +22,7 @@ class PortDroidShards(PipelineStep):
|
|||||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
from datasets.utils.tqdm import disable_progress_bars
|
from datasets.utils.tqdm import disable_progress_bars
|
||||||
|
|
||||||
from examples.port_datasets.droid_rlds.port_droid import port_droid
|
from examples.port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
|
||||||
from lerobot.common.utils.utils import init_logging
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
@@ -57,7 +38,7 @@ class PortDroidShards(PipelineStep):
|
|||||||
shard_index=rank,
|
shard_index=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
validate_shard(shard_repo_id)
|
validate_dataset(shard_repo_id)
|
||||||
|
|
||||||
|
|
||||||
def make_port_executor(
|
def make_port_executor(
|
||||||
|
|||||||
@@ -126,9 +126,8 @@ def load_nested_dataset(pq_dir: Path) -> Dataset:
|
|||||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||||
|
|
||||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||||
return concatenate_datasets(
|
datasets = [Dataset.from_parquet(str(path)) for path in paths]
|
||||||
[Dataset.from_parquet(str(path)) for path in sorted(pq_dir.glob("*/*.parquet"))]
|
return concatenate_datasets(datasets)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parquet_num_frames(parquet_path):
|
def get_parquet_num_frames(parquet_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user