WIP
@@ -17,37 +17,35 @@
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.

NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
NOTE: Install `tensorflow` and `tensorflow_datasets` before running this script.
```bash
pip install tensorflow
pip install tensorflow_datasets
```

Example:
python openx_rlds.py \
    --raw-dir /path/to/bridge_orig/1.0.0 \
    --local-dir /path/to/local_dir \
    --repo-id your_id \
    --use-videos \
    --push-to-hub
```bash
python examples/port_datasets/openx_rlds.py \
    --raw-dir /fsx/mustafa_shukor/droid \
    --repo-id cadene/droid \
    --use-videos \
    --push-to-hub
```
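After the port finishes (with `--push-to-hub`), the result can be loaded back like any other LeRobotDataset. A minimal sketch, assuming the `cadene/droid` repo id from the example above:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load the ported dataset from the Hugging Face Hub (downloads on first use).
dataset = LeRobotDataset("cadene/droid")
print(dataset.num_episodes, dataset.fps)
print(dataset[0]["observation.state"])  # a single frame, returned as tensors
```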
"""

import argparse
import os
import re
import shutil
import sys
from functools import partial
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset

current_dir = os.path.dirname(os.path.abspath(__file__))
oxe_utils_dir = os.path.join(current_dir, "oxe_utils")
sys.path.append(oxe_utils_dir)

from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
from examples.port_datasets.openx_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from examples.port_datasets.openx_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

np.set_printoptions(precision=2)

@@ -87,16 +85,23 @@ def transform_raw_dataset(episode, dataset_name):
    return episode


def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bool = True):
    dataset_name = builder.name

def generate_features_from_raw(dataset_name: str, builder: tfds.core.DatasetBuilder, use_videos: bool = True):
    state_names = [f"motor_{i}" for i in range(8)]
    if dataset_name in OXE_DATASET_CONFIGS:
        state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
        if state_encoding == StateEncoding.POS_EULER:
            state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
            if "libero" in dataset_name:
                state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"]  # 2D gripper state
                state_names = [
                    "x",
                    "y",
                    "z",
                    "roll",
                    "pitch",
                    "yaw",
                    "gripper",
                    "gripper",
                ]  # 2D gripper state
        elif state_encoding == StateEncoding.POS_QUAT:
            state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
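For reference, a hypothetical sketch of the kind of `observation.state` feature entry these state names feed into; the actual schema is built in the elided remainder of `generate_features_from_raw`, so the exact keys below are an assumption:

```python
# Hypothetical sketch only; the real feature spec is constructed in the elided code.
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
features = {
    "observation.state": {
        "dtype": "float32",
        "shape": (len(state_names),),
        "names": {"motors": state_names},
    },
}
print(features["observation.state"]["shape"])  # (8,)
```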
@@ -126,44 +131,68 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
    return {**features, **DEFAULT_FEATURES}


def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.data.Dataset, **kwargs):
    for episode in raw_dataset.as_numpy_iterator():
def save_as_lerobot_dataset(
    dataset_name: str,
    lerobot_dataset: LeRobotDataset,
    raw_dataset: tf.data.Dataset,
    num_shards: int | None = None,
    shard_index: int | None = None,
):
    total_num_episodes = raw_dataset.cardinality().numpy().item()
    print(f"Total number of episodes {total_num_episodes}")

    if num_shards is not None:
        num_shards = 10000
        shard_index = 9999
        sharded_dataset = raw_dataset.shard(num_shards=num_shards, index=shard_index)
        sharded_num_episodes = sharded_dataset.cardinality().numpy().item()
        print(f"{sharded_num_episodes=}")
        num_episodes = sharded_num_episodes
        iter_ = iter(sharded_dataset)
    else:
        num_episodes = total_num_episodes
        iter_ = iter(raw_dataset)

    for episode_index in range(num_episodes):
        print(f"{episode_index} / {num_episodes}")
        episode = next(iter_)
        print("\nnext\n")
        episode = transform_raw_dataset(episode, dataset_name)

        traj = episode["steps"]
        for i in range(traj["action"].shape[0]):
        for i in tqdm.tqdm(range(traj["action"].shape[0])):
            image_dict = {
                f"observation.images.{key}": value[i]
                f"observation.images.{key}": value[i].numpy()
                for key, value in traj["observation"].items()
                if "depth" not in key and any(x in key for x in ["image", "rgb"])
            }
            lerobot_dataset.add_frame(
                {
                    **image_dict,
                    "observation.state": traj["proprio"][i],
                    "action": traj["action"][i],
                    "observation.state": traj["proprio"][i].numpy(),
                    "action": traj["action"][i].numpy(),
                    "task": traj["task"][i].numpy().decode(),
                }
            )
        lerobot_dataset.save_episode(task=traj["task"][0].decode())

    lerobot_dataset.consolidate(
        run_compute_stats=True,
        keep_image_files=kwargs["keep_images"],
        stat_kwargs={"batch_size": kwargs["batch_size"], "num_workers": kwargs["num_workers"]},
    )
        print()
        lerobot_dataset.save_episode()
        print("\nsave_episode\n")

        break
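The `num_shards` / `shard_index` pair above is there so that several jobs can port disjoint slices of the raw dataset in parallel. A minimal sketch of the `tf.data.Dataset.shard` semantics this relies on (toy dataset, not the actual episodes):

```python
import tensorflow as tf

# Shard `index` out of `num_shards` keeps every num_shards-th element:
# shard 1 of 4 over range(10) yields 1, 5, 9.
ds = tf.data.Dataset.range(10)
shard = ds.shard(num_shards=4, index=1)
print(list(shard.as_numpy_iterator()))  # [1, 5, 9]
```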


def create_lerobot_dataset(
    raw_dir: Path,
    repo_id: str = None,
    local_dir: Path = None,
    push_to_hub: bool = False,
    fps: int = None,
    robot_type: str = None,
    use_videos: bool = True,
    batch_size: int = 32,
    num_workers: int = 8,
    image_writer_process: int = 5,
    image_writer_threads: int = 10,
    keep_images: bool = True,
    num_shards: int | None = None,
    shard_index: int | None = None,
):
    last_part = raw_dir.name
    if re.match(r"^\d+\.\d+\.\d+$", last_part):
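As a quick illustration of the version check above: the last path component is treated as a TFDS version only when it looks like `X.Y.Z` (throwaway snippet, example names made up):

```python
import re

for last_part in ("1.0.0", "bridge_orig", "droid"):
    print(last_part, bool(re.match(r"^\d+\.\d+\.\d+$", last_part)))
# 1.0.0 True
# bridge_orig False
# droid False
```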
@@ -175,15 +204,9 @@ def create_lerobot_dataset(
        dataset_name = last_part
        data_dir = raw_dir.parent

    if local_dir is None:
        local_dir = Path(LEROBOT_HOME)
    local_dir /= f"{dataset_name}_{version}_lerobot"
    if local_dir.exists():
        shutil.rmtree(local_dir)

    builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
    features = generate_features_from_raw(builder, use_videos)
    raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name))
    features = generate_features_from_raw(dataset_name, builder, use_videos)
    raw_dataset = builder.as_dataset(split="train")

    if fps is None:
        if dataset_name in OXE_DATASET_CONFIGS:
@@ -201,7 +224,6 @@ def create_lerobot_dataset(
    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        robot_type=robot_type,
        root=local_dir,
        fps=fps,
        use_videos=use_videos,
        features=features,
@@ -210,16 +232,18 @@ def create_lerobot_dataset(
    )

    save_as_lerobot_dataset(
        lerobot_dataset, raw_dataset, keep_images=keep_images, batch_size=batch_size, num_workers=num_workers
        dataset_name,
        lerobot_dataset,
        raw_dataset,
        num_shards=num_shards,
        shard_index=shard_index,
    )

    if push_to_hub:
        assert repo_id is not None
        tags = ["LeRobot", dataset_name, "rlds"]
        tags = []
        if dataset_name in OXE_DATASET_CONFIGS:
            tags.append("openx")
        if robot_type != "unknown":
            tags.append(robot_type)
        lerobot_dataset.push_to_hub(
            tags=tags,
            private=False,
@@ -237,12 +261,6 @@ def main():
        required=True,
        help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version`).",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        required=True,
        help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
@@ -270,37 +288,25 @@ def main():
        action="store_true",
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 times faster loading time during training.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size used by the DataLoader for computing the dataset statistics.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Number of DataLoader worker processes for computing the dataset statistics.",
    )
    parser.add_argument(
        "--image-writer-process",
        type=int,
        default=5,
        default=0,
        help="Number of processes of image writer for saving images.",
    )
    parser.add_argument(
        "--image-writer-threads",
        type=int,
        default=10,
        default=8,
        help="Number of threads per process of image writer for saving images.",
    )
    parser.add_argument(
        "--keep-images",
        action="store_true",
        help="Whether to keep the cached images.",
    )

    args = parser.parse_args()

    droid_dir = Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene/droid")
    if droid_dir.exists():
        shutil.rmtree(droid_dir)

    create_lerobot_dataset(**vars(args))
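For reference, `main()` just forwards the parsed arguments to `create_lerobot_dataset`, so the port can also be driven directly from Python. A minimal sketch mirroring the example command at the top of the file (paths and repo id are the ones from that example; adjust to your setup):

```python
from pathlib import Path

# Hypothetical direct call, equivalent to the example CLI invocation above.
create_lerobot_dataset(
    raw_dir=Path("/fsx/mustafa_shukor/droid"),
    repo_id="cadene/droid",
    use_videos=True,
    push_to_hub=True,
)
```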