Support for converting OpenX datasets from RLDS format to LeRobotDataset (#354)

Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: youliangtan <tan_you_liang@hotmail.com>
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Michel Aractingi
2024-08-27 09:07:00 +02:00
committed by GitHub
parent aad59e6b6b
commit eb4c505cff
12 changed files with 2329 additions and 6 deletions

View File

@@ -66,6 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif "openx_rlds" in raw_format:
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
elif raw_format == "dora_parquet":
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
elif raw_format == "xarm_pkl":
@@ -197,9 +199,25 @@ def push_dataset_to_hub(
# convert dataset from original raw format to LeRobot format
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir, videos_dir, fps, video, episodes, encoding
)
fmt_kwgs = {
"raw_dir": raw_dir,
"videos_dir": videos_dir,
"fps": fps,
"video": video,
"episodes": episodes,
"encoding": encoding,
}
if "openx_rlds." in raw_format:
# Support for official OXE dataset name inside `raw_format`.
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
_, openx_dataset_name = raw_format.split(".")
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
@@ -268,7 +286,7 @@ def main():
"--raw-format",
type=str,
required=True,
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).",
)
parser.add_argument(
"--repo-id",
@@ -328,6 +346,13 @@ def main():
default=0,
help="When set to 1, resumes a previous run.",
)
parser.add_argument(
"--cache-dir",
type=Path,
required=False,
default="/tmp",
help="Directory to store the temporary videos and images generated while creating the dataset.",
)
parser.add_argument(
"--tests-data-dir",
type=Path,