forked from tangger/lerobot
Compare commits
2 Commits
recovered-
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6cc2cc896a | ||
|
|
acc433d25d |
282
examples/port_datasets/rlds_openx.py
Normal file
282
examples/port_datasets/rlds_openx.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
|
def tf_to_torch(data):
    """Convert a tensor exposing ``.numpy()`` (e.g. a tf.Tensor) into a torch.Tensor.

    The conversion goes through the tensor's numpy buffer, so no TensorFlow
    op is involved beyond materializing the values.
    """
    array = data.numpy()
    return torch.from_numpy(array)
|
||||||
|
|
||||||
|
|
||||||
|
def tf_img_convert(img):
    """Convert one image from a tfds step into a torch.Tensor.

    Two encodings are accepted:
      * ``tf.string`` — encoded bytes (e.g. jpeg/png), decoded to a uint8 tensor;
      * ``tf.uint8`` — already-decoded pixels, passed through.

    Raises:
        ValueError: for any other image dtype (e.g. float depth maps).
    """
    if img.dtype == tf.string:
        # Encoded bytes: decode to uint8 pixels (expand_animations=False keeps
        # the result a single image rather than a gif frame stack).
        decoded = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
        return torch.from_numpy(decoded.numpy())
    if img.dtype != tf.uint8:
        raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
    return torch.from_numpy(img.numpy())
|
||||||
|
|
||||||
|
|
||||||
|
def get_type(dtype):
    """Translate a tf.DType into the dtype string used in LeRobot feature specs.

    Args:
        dtype: a TensorFlow dtype taken from a tfds feature's ``.dtype``.

    Returns:
        The corresponding dtype name ("uint8", "float32", "float64", "bool"
        or "str").

    Raises:
        ValueError: if `dtype` has no known mapping. (Previously the function
        fell through and returned None, which silently produced feature specs
        with ``"dtype": None`` downstream.)
    """
    if dtype == tf.uint8:
        return "uint8"
    elif dtype == tf.float32:
        return "float32"
    elif dtype == tf.float64:
        return "float64"
    elif dtype == tf.bool:
        return "bool"
    elif dtype == tf.string:
        return "str"
    raise ValueError(f"Unsupported dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
    """
    In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
    entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
    trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.

    NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
    """
    steps = traj.pop("steps")

    # Trajectory length = leading dimension of the first flattened leaf tensor
    # of "steps" (all step tensors share the same first dimension).
    traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]

    # broadcast metadata to the length of the trajectory
    metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)

    # put steps back in
    assert "traj_metadata" not in steps
    traj = {**steps, "traj_metadata": metadata}

    # Guard against key collisions before injecting the synthetic index fields.
    assert "_len" not in traj
    assert "_traj_index" not in traj
    assert "_frame_index" not in traj
    traj["_len"] = tf.repeat(traj_len, traj_len)
    traj["_traj_index"] = tf.repeat(i, traj_len)
    traj["_frame_index"] = tf.range(traj_len)

    return traj
|
||||||
|
|
||||||
|
|
||||||
|
def load_raw_dataset(path: Path):
    """Open an RLDS/OpenX dataset directory with tfds and broadcast its metadata.

    Image decoding for the "steps" entries is skipped here (SkipDecoding) and
    done later, frame by frame, in the porting functions.

    Returns:
        A ``(dataset, dataset_info)`` pair, where each element of ``dataset``
        is one full trajectory after `_broadcast_metadata_rlds`.
    """
    builder = tfds.builder_from_directory(str(path))
    raw_ds = builder.as_dataset(
        split="all",
        decoders={"steps": tfds.decode.SkipDecoding()},
    )

    info = builder.info
    print("dataset_info: ", info)

    raw_ds = raw_ds.take(len(raw_ds))
    # "flatten" the dataset as such we can apply trajectory level map() easily
    # each [obs][key] has a shape of (frame_size, ...)
    flattened = raw_ds.enumerate().map(_broadcast_metadata_rlds)

    return flattened, info
|
||||||
|
|
||||||
|
|
||||||
|
def build_features_and_dataset_keys(dataset_info):
    """Derive LeRobot feature specs and key groupings from a tfds ``dataset_info``.

    Returns:
        features: dict mapping LeRobot feature names to
            ``{"dtype": ..., "shape": ..., "name": None}`` entries.
        image_keys: observation sub-keys treated as rgb images (3D uint8).
        state_keys: every other observation sub-key (vector/state data).
        other_keys: non-observation step keys kept in the features
            (action, reward, is_last, ...), excluding language instructions.
    """
    features = {}
    image_keys = []
    state_keys = []
    other_keys = []
    for key, data_info in dataset_info.features["steps"].items():
        if "observation" in key:
            # check whether the key is for an image or a vector observation
            # only add rgb images, discard depth
            # NOTE(review): the else branch below also captures 3D non-uint8
            # tensors (e.g. float depth maps) as state — confirm intended.
            for k, info in data_info.items():
                if len(info.shape) == 3 and info.dtype == tf.uint8:
                    image_keys.append(k)
                    dtype = "video"
                    shape = info.shape
                    # TODO (michel_aractingi) add info[key].doc for feature description
                    features["observation.image." + k] = {"dtype": dtype, "shape": shape, "name": None}
                else:
                    state_keys.append(k)
                    dtype = get_type(info.dtype)
                    shape = info.shape
                    # TODO (michel_aractingi) add info[key].doc for feature description
                    features["observation.state." + k] = {"dtype": dtype, "shape": shape, "name": None}
        else:
            # Only plain tensors are supported at the top level for now.
            if type(data_info) is tfds.features.Tensor:
                # TODO extend features to take language instructions
                if "language_instruction" in key:
                    continue
                other_keys.append(key)
                dtype = get_type(data_info.dtype)
                shape = data_info.shape
                # Scalars are stored as length-1 vectors in the feature spec.
                if len(shape) == 0:
                    shape = (1,)
                # Map RLDS conventions onto the LeRobot "next.*" naming.
                if key == "is_last":
                    features["next.done"] = {"dtype": dtype, "shape": shape, "name": None}
                elif key == "reward":
                    features["next.reward"] = {"dtype": dtype, "shape": shape, "name": None}
                else:
                    features[key] = {"dtype": dtype, "shape": shape, "name": None}
            # elif type(data_info) is tfds.features.FeaturesDict: TODO add dictionary based variables

    return features, image_keys, state_keys, other_keys
|
||||||
|
|
||||||
|
|
||||||
|
def to_lerobotdataset_with_save_episode(raw_dir: Path, repo_id: str, push_to_hub: bool = True, fps=30):
    """Port an RLDS/OpenX dataset to a LeRobotDataset, one whole episode at a time.

    Builds the complete per-episode tensors in memory and hands them to
    ``LeRobotDataset.save_episode(episode_data=...)`` (as opposed to the
    frame-by-frame ``add_frame`` path in ``to_lerobotdataset_with_add_frame``).

    Args:
        raw_dir: directory of the raw dataset in rlds/openx format.
        repo_id: hub-style identifier ('{user}/{name}'). Any existing local
            copy under LEROBOT_HOME / repo_id is deleted first.
        push_to_hub: whether to upload the converted dataset to the Hub.
        fps: frames per second, used to synthesize per-frame timestamps.
    """
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    dataset, dataset_info = load_raw_dataset(path=raw_dir)

    # Build features
    features, image_keys, state_keys, other_keys = build_features_and_dataset_keys(dataset_info)

    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        features=features,
        use_videos=True,
        image_writer_threads=4,
    )

    # Each element of `dataset` is one full EPISODE/trajectory (not one frame);
    # len(dataset) is the number of trajectories. Iterate directly with
    # enumerate instead of pairing range(len(...)) with a manual next() on a
    # separate iterator.
    for ep_idx, episode in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
        episode_data = {}
        num_frames = episode["action"].shape[0]
        lang_instruction = episode["language_instruction"].numpy()[0].decode("utf-8")

        for key in state_keys:
            episode_data["observation.state." + key] = tf_to_torch(episode["observation"][key])
        for key in image_keys:
            decoded_images = [tf_img_convert(img) for img in episode["observation"][key]]
            episode_data["observation.image." + key] = decoded_images

        for key in other_keys:
            if "language_instruction" in key:
                # Some openx dataset have multiple language commands
                episode_data[key] = episode[key].numpy()[0].decode("utf-8")
            elif key == "is_last":
                episode_data["next.done"] = tf_to_torch(episode[key])
            elif key == "reward":
                episode_data["next.reward"] = tf_to_torch(episode[key])
            else:
                episode_data[key] = tf_to_torch(episode[key])

        episode_data["size"] = num_frames
        episode_data["episode_index"] = ep_idx  # torch.tensor([ep_idx] * num_frames)
        episode_data["frame_index"] = torch.arange(0, num_frames, 1)
        episode_data["timestamp"] = torch.arange(0, num_frames, 1) / fps
        episode_data["task_index"] = 0  # TODO calculate task index correctly
        episode_data["index"] = 0  # TODO figure out what index is for in DEFAULT_FEATURES

        lerobot_dataset.save_episode(task=lang_instruction, episode_data=episode_data)

    lerobot_dataset.consolidate()

    if push_to_hub:
        lerobot_dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
def to_lerobotdataset_with_add_frame(raw_dir: Path, repo_id: str, push_to_hub: bool = True, fps=30):
    """Port an RLDS/OpenX dataset to a LeRobotDataset via per-frame add_frame calls.

    Decodes each episode into tensors, then feeds the dataset frame by frame
    with ``LeRobotDataset.add_frame`` before finalizing the episode with
    ``save_episode`` (contrast with ``to_lerobotdataset_with_save_episode``,
    which hands over the whole episode at once).

    Args:
        raw_dir: directory of the raw dataset in rlds/openx format.
        repo_id: hub-style identifier ('{user}/{name}'). Any existing local
            copy under LEROBOT_HOME / repo_id is deleted first.
        push_to_hub: whether to upload the converted dataset to the Hub.
        fps: frames per second of the source data.
    """
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    dataset, dataset_info = load_raw_dataset(path=raw_dir)

    # Build features, get keys
    features, image_keys, state_keys, other_keys = build_features_and_dataset_keys(dataset_info)

    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        features=features,
        use_videos=True,
        image_writer_threads=4,
    )

    # Each element of `dataset` is one full EPISODE/trajectory (not one frame);
    # len(dataset) is the number of trajectories. The episode index itself is
    # unused here, so iterate the dataset directly instead of driving a manual
    # iterator with range(len(...)) and next().
    for episode in tqdm.tqdm(dataset, total=len(dataset)):
        episode_data = {}
        num_frames = episode["action"].shape[0]
        lang_instruction = episode["language_instruction"].numpy()[0].decode("utf-8")

        for key in state_keys:
            episode_data["observation.state." + key] = tf_to_torch(episode["observation"][key])
        for key in image_keys:
            decoded_images = [tf_img_convert(img) for img in episode["observation"][key]]
            episode_data["observation.image." + key] = decoded_images

        for key in other_keys:
            if "language_instruction" in key:
                # Some openx dataset have multiple language commands
                # like droid has 1-3 language instructions for some trajectories
                episode_data[key] = episode[key].numpy()[0].decode("utf-8")
            elif key == "is_last":
                episode_data["next.done"] = tf_to_torch(episode[key])
            elif key == "reward":
                episode_data["next.reward"] = tf_to_torch(episode[key])
            else:
                episode_data[key] = tf_to_torch(episode[key])

        for i in range(num_frames):
            frame = {}
            for key, value in episode_data.items():
                # Language instructions are episode-level strings, replicated
                # unchanged on every frame; everything else is indexed per frame.
                frame[key] = value if "language_instruction" in key else value[i]

            lerobot_dataset.add_frame(frame)

        lerobot_dataset.save_episode(task=lang_instruction)

    lerobot_dataset.consolidate()

    if push_to_hub:
        lerobot_dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI entry point: convert one raw rlds/openx dataset directory into a
    # LeRobotDataset using the frame-by-frame (add_frame) porting path.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser.add_argument(
        "--raw-dir",
        type=str,
        required=True,
        help="Path to the directory of the raw dataset in rlds/openx format.",
    )
    parser.add_argument(
        "--push-to-hub",
        type=int,
        default=0,
        help="Binary value to indicate whether you want to push the dataset to the HuggingFace Hub.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=30,
        help="frames per second, can be found the openx spreadsheet for openx datasets."
        "https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0",
    )

    args = parser.parse_args()

    # NOTE(review): --push-to-hub is parsed as an int (0/1) but forwarded to a
    # bool parameter; truthiness makes this work, though bool(args.push_to_hub)
    # would be clearer. raw_dir is passed as a str where a Path is annotated —
    # works because it is only used with tfds.builder_from_directory.
    to_lerobotdataset_with_add_frame(args.raw_dir, args.repo_id, args.push_to_hub, args.fps)
|
||||||
@@ -26,7 +26,6 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
|||||||
|
|
||||||
Note: We assume the images are in channel first format
|
Note: We assume the images are in channel first format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
|
|||||||
@@ -744,9 +744,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||||
time for video encoding.
|
time for video encoding.
|
||||||
"""
|
"""
|
||||||
if not episode_data:
|
|
||||||
episode_buffer = self.episode_buffer
|
|
||||||
|
|
||||||
|
episode_buffer = episode_data if episode_data else self.episode_buffer
|
||||||
episode_length = episode_buffer.pop("size")
|
episode_length = episode_buffer.pop("size")
|
||||||
episode_index = episode_buffer["episode_index"]
|
episode_index = episode_buffer["episode_index"]
|
||||||
if episode_index != self.meta.total_episodes:
|
if episode_index != self.meta.total_episodes:
|
||||||
@@ -762,7 +761,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
task_index = self.meta.get_task_index(task)
|
task_index = self.meta.get_task_index(task)
|
||||||
|
|
||||||
if not set(episode_buffer.keys()) == set(self.features):
|
if not set(episode_buffer.keys()) == set(self.features):
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
||||||
@@ -775,7 +774,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episode_buffer[key] = np.full((episode_length,), episode_index)
|
episode_buffer[key] = np.full((episode_length,), episode_index)
|
||||||
elif key == "task_index":
|
elif key == "task_index":
|
||||||
episode_buffer[key] = np.full((episode_length,), task_index)
|
episode_buffer[key] = np.full((episode_length,), task_index)
|
||||||
elif ft["dtype"] in ["image", "video"]:
|
elif ft["dtype"] in ["image", "video"] or "language_instruction" in key:
|
||||||
continue
|
continue
|
||||||
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
||||||
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
||||||
@@ -896,7 +895,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
if run_compute_stats:
|
if run_compute_stats:
|
||||||
self.stop_image_writer()
|
self.stop_image_writer()
|
||||||
# TODO(aliberts): refactor stats in save_episodes
|
# TODO(aliberts): refactor stats in save_episodes
|
||||||
self.meta.stats = compute_stats(self)
|
self.meta.stats = compute_stats(self, num_workers=0)
|
||||||
serialized_stats = serialize_dict(self.meta.stats)
|
serialized_stats = serialize_dict(self.meta.stats)
|
||||||
write_json(serialized_stats, self.root / STATS_PATH)
|
write_json(serialized_stats, self.root / STATS_PATH)
|
||||||
self.consolidated = True
|
self.consolidated = True
|
||||||
|
|||||||
Reference in New Issue
Block a user