Compare commits

...

2 Commits

Author SHA1 Message Date
Michel Aractingi
6cc2cc896a set num workers to zero in compute stats, temporary fix to hanging code when iterating over dataloader 2025-01-20 11:02:37 +00:00
Michel Aractingi
acc433d25d Added example script to port rlds/openx dataset formats 2025-01-18 10:49:26 +00:00
3 changed files with 286 additions and 6 deletions

View File

@@ -0,0 +1,282 @@
import argparse
import shutil
from pathlib import Path
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
def tf_to_torch(data):
    """Convert a TensorFlow tensor to a torch tensor via its numpy buffer."""
    array = data.numpy()
    return torch.from_numpy(array)
def tf_img_convert(img):
    """Return a torch uint8 image tensor, decoding encoded (string) images first.

    Accepts either an already-decoded uint8 image tensor or an encoded image
    stored as a tf.string scalar; any other dtype is rejected.
    """
    if img.dtype == tf.string:
        # Encoded bytes: decode to a uint8 image tensor.
        decoded = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
        return torch.from_numpy(decoded.numpy())
    if img.dtype != tf.uint8:
        raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
    return torch.from_numpy(img.numpy())
def get_type(dtype):
    """Map a TensorFlow dtype to the dtype string used in LeRobot feature specs.

    Args:
        dtype: a TensorFlow dtype (e.g. tf.uint8, tf.float32).

    Returns:
        The corresponding dtype name as a string.

    Raises:
        ValueError: if the dtype has no known mapping. (Previously an unmapped
            dtype silently fell through and returned None, producing an invalid
            ``dtype: None`` entry in the feature dict; fail fast instead.)
    """
    if dtype == tf.uint8:
        return "uint8"
    elif dtype == tf.float32:
        return "float32"
    elif dtype == tf.float64:
        return "float64"
    elif dtype == tf.bool:
        return "bool"
    elif dtype == tf.string:
        return "str"
    raise ValueError(f"Unsupported dtype: {dtype}")
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
    """
    In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
    entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
    trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
    NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
    """
    steps = traj.pop("steps")
    # Trajectory length = leading dimension of any tensor nested under "steps".
    traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
    # broadcast metadata to the length of the trajectory
    metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
    # put steps back in
    assert "traj_metadata" not in steps
    traj = {**steps, "traj_metadata": metadata}
    # Guard against key collisions before injecting the synthetic index fields.
    assert "_len" not in traj
    assert "_traj_index" not in traj
    assert "_frame_index" not in traj
    # Per-frame bookkeeping: episode length, episode index (i comes from
    # dataset.enumerate()), and the frame's position inside the episode.
    traj["_len"] = tf.repeat(traj_len, traj_len)
    traj["_traj_index"] = tf.repeat(i, traj_len)
    traj["_frame_index"] = tf.range(traj_len)
    return traj
def load_raw_dataset(path: Path):
    """Load an RLDS/OpenX dataset from disk and flatten it to trajectory level.

    Args:
        path: directory containing a TFDS-built dataset.

    Returns:
        A tuple ``(dataset, dataset_info)`` where each element of ``dataset``
        is one full trajectory with its metadata broadcast to every frame.
    """
    ds_builder = tfds.builder_from_directory(str(path))
    dataset = ds_builder.as_dataset(
        split="all",
        # Defer decoding of "steps"; images are decoded later per frame.
        decoders={"steps": tfds.decode.SkipDecoding()},
    )
    dataset_info = ds_builder.info
    print("dataset_info: ", dataset_info)
    ds_length = len(dataset)
    # NOTE(review): take(len(dataset)) keeps every element — presumably a
    # no-op kept to make the dataset finite/explicit; confirm it is needed.
    dataset = dataset.take(ds_length)
    # "flatten" the dataset as such we can apply trajectory level map() easily
    # each [obs][key] has a shape of (frame_size, ...)
    dataset = dataset.enumerate().map(_broadcast_metadata_rlds)
    return dataset, dataset_info
def build_features_and_dataset_keys(dataset_info):
    """Build the LeRobot feature dict and the per-category key lists from TFDS info.

    Args:
        dataset_info: the TFDS ``DatasetInfo``; only ``features["steps"]`` is read.

    Returns:
        A tuple ``(features, image_keys, state_keys, other_keys)``:
        - features: dict mapping LeRobot feature names to {"dtype", "shape", "name"}
        - image_keys: observation sub-keys mapped to "observation.image.*"
        - state_keys: observation sub-keys mapped to "observation.state.*"
        - other_keys: remaining scalar/tensor step keys kept as-is
    """
    features = {}
    image_keys = []
    state_keys = []
    other_keys = []
    for key, data_info in dataset_info.features["steps"].items():
        if "observation" in key:
            # check whether the key is for an image or a vector observation
            # only add rgb images, discard depth
            for k, info in data_info.items():
                # NOTE(review): only (H, W, C) uint8 tensors are treated as RGB
                # images; a 3-D non-uint8 tensor (e.g. a float depth map) falls
                # into the state branch below despite the comment — confirm intended.
                if len(info.shape) == 3 and info.dtype == tf.uint8:
                    image_keys.append(k)
                    dtype = "video"
                    shape = info.shape
                    # TODO (michel_aractingi) add info[key].doc for feature description
                    features["observation.image." + k] = {"dtype": dtype, "shape": shape, "name": None}
                else:
                    state_keys.append(k)
                    # NOTE(review): get_type returns None for unmapped dtypes — verify inputs.
                    dtype = get_type(info.dtype)
                    shape = info.shape
                    # TODO (michel_aractingi) add info[key].doc for feature description
                    features["observation.state." + k] = {"dtype": dtype, "shape": shape, "name": None}
        else:
            if type(data_info) is tfds.features.Tensor:
                # TODO extend features to take language instructions
                if "language_instruction" in key:
                    continue
                other_keys.append(key)
                dtype = get_type(data_info.dtype)
                shape = data_info.shape
                # Promote scalars to shape (1,) so every feature has a rank >= 1.
                if len(shape) == 0:
                    shape = (1,)
                # Rename terminal/reward signals to the LeRobot "next.*" convention.
                if key == "is_last":
                    features["next.done"] = {"dtype": dtype, "shape": shape, "name": None}
                elif key == "reward":
                    features["next.reward"] = {"dtype": dtype, "shape": shape, "name": None}
                else:
                    features[key] = {"dtype": dtype, "shape": shape, "name": None}
            # elif type(data_info) is tfds.features.FeaturesDict: TODO add dictionary based variables
    return features, image_keys, state_keys, other_keys
def to_lerobotdataset_with_save_episode(raw_dir: Path, repo_id: str, push_to_hub: bool = True, fps=30):
    """Port an RLDS/OpenX dataset to LeRobot format, one whole episode per save_episode call.

    Args:
        raw_dir: directory of the raw dataset in rlds/openx format.
        repo_id: target dataset id; any existing local copy under LEROBOT_HOME is deleted.
        push_to_hub: if truthy, upload the converted dataset to the HuggingFace Hub.
        fps: frames per second written to the dataset metadata.
    """
    # Start from a clean slate: remove any previous local conversion.
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)
    dataset, dataset_info = load_raw_dataset(path=raw_dir)
    # Build features
    features, image_keys, state_keys, other_keys = build_features_and_dataset_keys(dataset_info)
    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        features=features,
        use_videos=True,
        image_writer_threads=4,
    )
    it = iter(dataset)
    lang_instruction = None
    # The iterator it loops over each EPISODE in dataset (not frame-by-frame)
    # len(dataset) is the number of trajectories/episodes in a dataset
    for ep_idx in tqdm.tqdm(range(len(dataset))):
        episode = next(it)
        episode_data = {}
        # Episode length = leading dimension of the action tensor.
        num_frames = episode["action"].shape[0]
        # assumes every episode carries at least one language instruction — TODO confirm
        lang_instruction = episode["language_instruction"].numpy()[0].decode("utf-8")
        for key in state_keys:
            episode_data["observation.state." + key] = tf_to_torch(episode["observation"][key])
        for key in image_keys:
            decoded_images = [tf_img_convert(img) for img in episode["observation"][key]]
            episode_data["observation.image." + key] = decoded_images
        for key in other_keys:
            if "language_instruction" in key:
                # Some openx dataset have multiple language commands
                episode_data[key] = episode[key].numpy()[0].decode("utf-8")
            else:
                if key == "is_last":
                    episode_data["next.done"] = tf_to_torch(episode[key])
                elif key == "reward":
                    episode_data["next.reward"] = tf_to_torch(episode[key])
                else:
                    episode_data[key] = tf_to_torch(episode[key])
        # Bookkeeping fields that save_episode expects in the episode buffer.
        episode_data["size"] = num_frames
        episode_data["episode_index"] = ep_idx  # torch.tensor([ep_idx] * num_frames)
        episode_data["frame_index"] = torch.arange(0, num_frames, 1)
        episode_data["timestamp"] = torch.arange(0, num_frames, 1) / fps
        episode_data["task_index"] = 0  # TODO calculate task index correctly
        episode_data["index"] = 0  # TODO figure out what index is for in DEFAULT_FEATURES
        lerobot_dataset.save_episode(task=lang_instruction, episode_data=episode_data)
    lerobot_dataset.consolidate()
    if push_to_hub:
        lerobot_dataset.push_to_hub()
def to_lerobotdataset_with_add_frame(raw_dir: Path, repo_id: str, push_to_hub: bool = True, fps=30):
    """Port an RLDS/OpenX dataset to LeRobot format, feeding frames one at a time via add_frame.

    Args:
        raw_dir: directory of the raw dataset in rlds/openx format.
        repo_id: target dataset id; any existing local copy under LEROBOT_HOME is deleted.
        push_to_hub: if truthy, upload the converted dataset to the HuggingFace Hub.
        fps: frames per second written to the dataset metadata.
    """
    # Start from a clean slate: remove any previous local conversion.
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)
    dataset, dataset_info = load_raw_dataset(path=raw_dir)
    # Build features, get keys
    features, image_keys, state_keys, other_keys = build_features_and_dataset_keys(dataset_info)
    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        features=features,
        use_videos=True,
        image_writer_threads=4,
    )
    it = iter(dataset)
    lang_instruction = None
    # The iterator it loops over each EPISODE in dataset (not frame-by-frame)
    # len(dataset) is the number of trajectories/episodes in a dataset
    for _ep_idx in tqdm.tqdm(range(len(dataset))):
        episode = next(it)
        episode_data = {}
        # Episode length = leading dimension of the action tensor.
        num_frames = episode["action"].shape[0]
        # assumes every episode carries at least one language instruction — TODO confirm
        lang_instruction = episode["language_instruction"].numpy()[0].decode("utf-8")
        for key in state_keys:
            episode_data["observation.state." + key] = tf_to_torch(episode["observation"][key])
        for key in image_keys:
            decoded_images = [tf_img_convert(img) for img in episode["observation"][key]]
            episode_data["observation.image." + key] = decoded_images
        for key in other_keys:
            if "language_instruction" in key:
                # Some openx dataset have multiple language commands
                # like droid has 1-3 language instructions for some trajectories
                episode_data[key] = episode[key].numpy()[0].decode("utf-8")
            else:
                if key == "is_last":
                    episode_data["next.done"] = tf_to_torch(episode[key])
                elif key == "reward":
                    episode_data["next.reward"] = tf_to_torch(episode[key])
                else:
                    episode_data[key] = tf_to_torch(episode[key])
        # Re-slice the per-episode tensors frame by frame for add_frame;
        # language instructions are per-episode strings, so they are repeated.
        for i in range(num_frames):
            frame = {}
            for key in episode_data:
                if "language_instruction" in key:
                    frame[key] = episode_data[key]
                else:
                    frame[key] = episode_data[key][i]
            lerobot_dataset.add_frame(frame)
        lerobot_dataset.save_episode(task=lang_instruction)
    lerobot_dataset.consolidate()
    if push_to_hub:
        lerobot_dataset.push_to_hub()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser.add_argument(
"--raw-dir",
type=str,
required=True,
help="Path to the directory of the raw dataset in rlds/openx format.",
)
parser.add_argument(
"--push-to-hub",
type=int,
default=0,
help="Binary value to indicate whether you want to push the dataset to the HuggingFace Hub.",
)
parser.add_argument(
"--fps",
type=int,
default=30,
help="frames per second, can be found the openx spreadsheet for openx datasets."
"https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0",
)
args = parser.parse_args()
to_lerobotdataset_with_add_frame(args.raw_dir, args.repo_id, args.push_to_hub, args.fps)

View File

@@ -26,7 +26,6 @@ def get_stats_einops_patterns(dataset, num_workers=0):
     Note: We assume the images are in channel first format
     """
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=num_workers,

View File

@@ -744,9 +744,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
         time for video encoding.
         """
-        if not episode_data:
-            episode_buffer = self.episode_buffer
+        episode_buffer = episode_data if episode_data else self.episode_buffer
         episode_length = episode_buffer.pop("size")
         episode_index = episode_buffer["episode_index"]
         if episode_index != self.meta.total_episodes:
@@ -762,7 +761,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         task_index = self.meta.get_task_index(task)
         if not set(episode_buffer.keys()) == set(self.features):
             raise ValueError()
@@ -775,7 +774,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             episode_buffer[key] = np.full((episode_length,), episode_index)
         elif key == "task_index":
             episode_buffer[key] = np.full((episode_length,), task_index)
-        elif ft["dtype"] in ["image", "video"]:
+        elif ft["dtype"] in ["image", "video"] or "language_instruction" in key:
             continue
         elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
             episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
@@ -896,7 +895,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if run_compute_stats:
             self.stop_image_writer()
             # TODO(aliberts): refactor stats in save_episodes
-            self.meta.stats = compute_stats(self)
+            self.meta.stats = compute_stats(self, num_workers=0)
             serialized_stats = serialize_dict(self.meta.stats)
             write_json(serialized_stats, self.root / STATS_PATH)
         self.consolidated = True