Refactor OpenX (#505)

This commit is contained in:
Michel Aractingi
2024-12-03 00:51:55 +01:00
committed by GitHub
parent 32eb0cec8f
commit a2c181992a
6 changed files with 58 additions and 1919 deletions

View File

@@ -66,7 +66,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif "openx_rlds" in raw_format:
elif raw_format in ["rlds", "openx"]:
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
elif raw_format == "dora_parquet":
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
@@ -204,24 +204,14 @@ def push_dataset_to_hub(
# convert dataset from original raw format to LeRobot format
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
fmt_kwgs = {
"raw_dir": raw_dir,
"videos_dir": videos_dir,
"fps": fps,
"video": video,
"episodes": episodes,
"encoding": encoding,
}
if "openx_rlds." in raw_format:
# Support for official OXE dataset name inside `raw_format`.
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
_, openx_dataset_name = raw_format.split(".")
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir,
videos_dir,
fps,
video,
episodes,
encoding,
)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
@@ -290,7 +280,7 @@ def main():
"--raw-format",
type=str,
required=True,
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).",
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
)
parser.add_argument(
"--repo-id",