Improve push_dataset_to_hub API + Add unit tests (#231)

Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Thomas Wolf
2024-06-13 15:18:02 +02:00
committed by GitHub
parent c38f535c9f
commit 125bd93e29
11 changed files with 750 additions and 419 deletions

View File

@@ -14,156 +14,119 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.
This file contains download scripts for raw datasets.
Example of usage:
```
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
--raw-dir data/cadene/pusht_raw \
--repo-id cadene/pusht_raw
```
"""
import io
import argparse
import logging
import shutil
import warnings
from pathlib import Path
import tqdm
from huggingface_hub import snapshot_download
def download_raw(raw_dir, dataset_id):
if "aloha" in dataset_id or "image" in dataset_id:
download_hub(raw_dir, dataset_id)
elif "pusht" in dataset_id:
download_pusht(raw_dir)
elif "xarm" in dataset_id:
download_xarm(raw_dir)
elif "umi" in dataset_id:
download_umi(raw_dir)
else:
raise ValueError(dataset_id)
def download_raw(raw_dir: Path, repo_id: str):
# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
)
user_id, dataset_id = repo_id.split("/")
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
import zipfile
import requests
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
def download_pusht(raw_dir: str):
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
if not dataset_id.endswith("_raw"):
warnings.warn(
f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
stacklevel=1,
)
raw_dir = Path(raw_dir)
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(pusht_url, raw_dir)
# file is created inside a useful "pusht" directory, so we move it out and delete the dir
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
shutil.rmtree(raw_dir / "pusht")
def download_xarm(raw_dir: Path):
"""Download all xarm datasets at once"""
import zipfile
import gdown
raw_dir = Path(raw_dir)
raw_dir.mkdir(parents=True, exist_ok=True)
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
print("Extracting...")
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
for pkl_path in zip_f.namelist():
if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
zip_f.extract(member=pkl_path)
# move to corresponding raw directory
extract_dir = pkl_path.replace("/buffer.pkl", "")
raw_pkl_path = raw_dir / "buffer.pkl"
shutil.move(pkl_path, raw_pkl_path)
shutil.rmtree(extract_dir)
zip_path.unlink()
def download_hub(raw_dir: Path, dataset_id: str):
raw_dir = Path(raw_dir)
# Send warning if raw_dir isn't well formated
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn(
f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
stacklevel=1,
)
raw_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
def download_umi(raw_dir: Path):
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
zarr_path = raw_dir / "cup_in_the_wild.zarr"
def download_all_raw_datasets():
data_dir = Path("data")
repo_ids = [
"cadene/pusht_image_raw",
"cadene/xarm_lift_medium_image_raw",
"cadene/xarm_lift_medium_replay_image_raw",
"cadene/xarm_push_medium_image_raw",
"cadene/xarm_push_medium_replay_image_raw",
"cadene/aloha_sim_insertion_human_image_raw",
"cadene/aloha_sim_insertion_scripted_image_raw",
"cadene/aloha_sim_transfer_cube_human_image_raw",
"cadene/aloha_sim_transfer_cube_scripted_image_raw",
"cadene/pusht_raw",
"cadene/xarm_lift_medium_raw",
"cadene/xarm_lift_medium_replay_raw",
"cadene/xarm_push_medium_raw",
"cadene/xarm_push_medium_replay_raw",
"cadene/aloha_sim_insertion_human_raw",
"cadene/aloha_sim_insertion_scripted_raw",
"cadene/aloha_sim_transfer_cube_human_raw",
"cadene/aloha_sim_transfer_cube_scripted_raw",
"cadene/aloha_mobile_cabinet_raw",
"cadene/aloha_mobile_chair_raw",
"cadene/aloha_mobile_elevator_raw",
"cadene/aloha_mobile_shrimp_raw",
"cadene/aloha_mobile_wash_pan_raw",
"cadene/aloha_mobile_wipe_wine_raw",
"cadene/aloha_static_battery_raw",
"cadene/aloha_static_candy_raw",
"cadene/aloha_static_coffee_raw",
"cadene/aloha_static_coffee_new_raw",
"cadene/aloha_static_cups_open_raw",
"cadene/aloha_static_fork_pick_up_raw",
"cadene/aloha_static_pingpong_test_raw",
"cadene/aloha_static_pro_pencil_raw",
"cadene/aloha_static_screw_driver_raw",
"cadene/aloha_static_tape_raw",
"cadene/aloha_static_thread_velcro_raw",
"cadene/aloha_static_towel_raw",
"cadene/aloha_static_vinh_cup_raw",
"cadene/aloha_static_vinh_cup_left_raw",
"cadene/aloha_static_ziploc_slide_raw",
"cadene/umi_cup_in_the_wild_raw",
]
for repo_id in repo_ids:
raw_dir = data_dir / repo_id
download_raw(raw_dir, repo_id)
raw_dir = Path(raw_dir)
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(url_cup_in_the_wild, zarr_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
)
args = parser.parse_args()
download_raw(**vars(args))
if __name__ == "__main__":
data_dir = Path("data")
dataset_ids = [
"pusht_image",
"xarm_lift_medium_image",
"xarm_lift_medium_replay_image",
"xarm_push_medium_image",
"xarm_push_medium_replay_image",
"aloha_sim_insertion_human_image",
"aloha_sim_insertion_scripted_image",
"aloha_sim_transfer_cube_human_image",
"aloha_sim_transfer_cube_scripted_image",
"pusht",
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
"aloha_mobile_cabinet",
"aloha_mobile_chair",
"aloha_mobile_elevator",
"aloha_mobile_shrimp",
"aloha_mobile_wash_pan",
"aloha_mobile_wipe_wine",
"aloha_static_battery",
"aloha_static_candy",
"aloha_static_coffee",
"aloha_static_coffee_new",
"aloha_static_cups_open",
"aloha_static_fork_pick_up",
"aloha_static_pingpong_test",
"aloha_static_pro_pencil",
"aloha_static_screw_driver",
"aloha_static_tape",
"aloha_static_thread_velcro",
"aloha_static_towel",
"aloha_static_vinh_cup",
"aloha_static_vinh_cup_left",
"aloha_static_ziploc_slide",
"umi_cup_in_the_wild",
]
for dataset_id in dataset_ids:
raw_dir = data_dir / f"{dataset_id}_raw"
download_raw(raw_dir, dataset_id)
main()

View File

@@ -30,6 +30,7 @@ from PIL import Image as PILImage
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool:
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
hdf5_files = list(raw_dir.glob("*.hdf5"))
ep_dicts = []
episode_data_index = {"from": [], "to": []}
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
num_episodes = len(hdf5_files)
id_from = 0
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
ep_dicts = []
ep_ids = episodes if episodes else range(num_episodes)
for ep_idx in tqdm.tqdm(ep_ids):
ep_path = hdf5_files[ep_idx]
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
@@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
@@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
assert isinstance(ep_idx, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
gc.collect()
# process first episode only
if debug:
break
data_dict = concatenate_episodes(ep_dicts)
return data_dict, episode_data_index
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def to_hf_dataset(data_dict, video) -> Dataset:
@@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset:
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
# sanity check
check_format(raw_dir)
if fps is None:
fps = 50
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
hf_dataset = to_hf_dataset(data_dir, video)
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,

View File

@@ -17,7 +17,6 @@
Contains utilities to process raw data format from dora-record
"""
import logging
import re
from pathlib import Path
@@ -26,10 +25,10 @@ import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame
from lerobot.common.utils.utils import init_logging
def check_format(raw_dir) -> bool:
@@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
return True
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
@@ -122,8 +121,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
# Create symlink to raw videos directory (that needs to be absolute not relative)
out_dir.mkdir(parents=True, exist_ok=True)
videos_dir = out_dir / "videos"
videos_dir.parent.mkdir(parents=True, exist_ok=True)
videos_dir.symlink_to((raw_dir / "videos").absolute())
# sanity check the video paths are well formated
@@ -156,16 +154,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
else:
raise ValueError(key)
# Get the episode index containing for each unique episode index
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
from_ = first_ep_index_df["start_index"].tolist()
to_ = from_[1:] + [len(df)]
episode_data_index = {
"from": from_,
"to": to_,
}
return data_dict, episode_data_index
return data_dict
def to_hf_dataset(data_dict, video) -> Dataset:
@@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset:
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
init_logging()
if debug:
logging.warning("debug=True not implemented. Falling back to debug=False.")
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
# sanity check
check_format(raw_dir)
@@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
if not video:
raise NotImplementedError()
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
hf_dataset = to_hf_dataset(data_df, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,

View File

@@ -27,6 +27,7 @@ from PIL import Image as PILImage
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -53,7 +54,7 @@ def check_format(raw_dir):
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
num_episodes = zarr_data.meta["episode_ends"].shape[0]
assert len(
{zarr_data[key].shape[0] for key in zarr_data.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
@@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
states = torch.from_numpy(zarr_data["state"])
actions = torch.from_numpy(zarr_data["action"])
ep_dicts = []
episode_data_index = {"from": [], "to": []}
# load data indices from which each episode starts and ends
from_ids, to_ids = [], []
from_idx = 0
for to_idx in zarr_data.meta["episode_ends"]:
from_ids.append(from_idx)
to_ids.append(to_idx)
from_idx = to_idx
id_from = 0
for ep_idx in tqdm.tqdm(range(num_episodes)):
id_to = zarr_data.meta["episode_ends"][ep_idx]
num_frames = id_to - id_from
num_episodes = len(from_ids)
ep_dicts = []
ep_ids = episodes if episodes else range(num_episodes)
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
from_idx = from_ids[selected_ep_idx]
to_idx = to_ids[selected_ep_idx]
num_frames = to_idx - from_idx
# sanity check
assert (episode_ids[id_from:id_to] == ep_idx).all()
assert (episode_ids[from_idx:to_idx] == ep_idx).all()
# get image
image = imgs[id_from:id_to]
image = imgs[from_idx:to_idx]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
# get state
state = states[id_from:id_to]
state = states[from_idx:to_idx]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
@@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
@@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = agent_pos
ep_dict["action"] = actions[id_from:id_to]
ep_dict["action"] = actions[from_idx:to_idx]
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
# process first episode only
if debug:
break
data_dict = concatenate_episodes(ep_dicts)
return data_dict, episode_data_index
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def to_hf_dataset(data_dict, video):
@@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video):
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
# sanity check
check_format(raw_dir)
if fps is None:
fps = 10
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,

View File

@@ -19,7 +19,6 @@ import logging
import shutil
from pathlib import Path
import numpy as np
import torch
import tqdm
import zarr
@@ -29,6 +28,7 @@ from PIL import Image as PILImage
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool:
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
# Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
from numba import jit
@jit(nopython=True)
def _get_episode_idxs(episode_ends):
result = np.zeros((episode_ends[-1],), dtype=np.int64)
start_idx = 0
for episode_number, end_idx in enumerate(episode_ends):
result[start_idx:end_idx] = episode_number
start_idx = end_idx
return result
return _get_episode_idxs(episode_ends)
def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
zarr_path = raw_dir / "cup_in_the_wild.zarr"
zarr_data = zarr.open(zarr_path, mode="r")
@@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
episode_ends = zarr_data["meta/episode_ends"][:]
num_episodes = episode_ends.shape[0]
episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
# We convert it in torch tensor later because the jit function does not support torch tensors
episode_ends = torch.from_numpy(episode_ends)
# load data indices from which each episode starts and ends
from_ids, to_ids = [], []
from_idx = 0
for to_idx in episode_ends:
from_ids.append(from_idx)
to_ids.append(to_idx)
from_idx = to_idx
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_idx in tqdm.tqdm(range(num_episodes)):
id_to = episode_ends[ep_idx]
num_frames = id_to - id_from
# sanity heck
assert (episode_ids[id_from:id_to] == ep_idx).all()
ep_ids = episodes if episodes else range(num_episodes)
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
from_idx = from_ids[selected_ep_idx]
to_idx = to_ids[selected_ep_idx]
num_frames = to_idx - from_idx
# TODO(rcadene): save temporary images of the episode?
state = states[id_from:id_to]
state = states[from_idx:to_idx]
ep_dict = {}
# load 57MB of images in RAM (400x224x224x3 uint8)
imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
@@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
ep_dict["end_pose"] = end_pose[id_from:id_to]
ep_dict["start_pos"] = start_pos[id_from:id_to]
ep_dict["gripper_width"] = gripper_width[id_from:id_to]
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
# process first episode only
if debug:
break
data_dict = concatenate_episodes(ep_dicts)
total_frames = id_from
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict, episode_data_index
return data_dict
def to_hf_dataset(data_dict, video):
@@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video):
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
# sanity check
check_format(raw_dir)
@@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
"Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
)
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,

View File

@@ -27,6 +27,7 @@ from PIL import Image as PILImage
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -54,37 +55,42 @@ def check_format(raw_dir):
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
def load_from_raw(raw_dir, out_dir, fps, video, debug):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
pkl_path = raw_dir / "buffer.pkl"
with open(pkl_path, "rb") as f:
pkl_data = pickle.load(f)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
id_to = 0
ep_idx = 0
total_frames = pkl_data["actions"].shape[0]
for i in tqdm.tqdm(range(total_frames)):
id_to += 1
if not pkl_data["dones"][i]:
# load data indices from which each episode starts and ends
from_ids, to_ids = [], []
from_idx, to_idx = 0, 0
for done in pkl_data["dones"]:
to_idx += 1
if not done:
continue
from_ids.append(from_idx)
to_ids.append(to_idx)
from_idx = to_idx
num_frames = id_to - id_from
num_episodes = len(from_ids)
image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
ep_dicts = []
ep_ids = episodes if episodes else range(num_episodes)
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
from_idx = from_ids[selected_ep_idx]
to_idx = to_ids[selected_ep_idx]
num_frames = to_idx - from_idx
image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
image = einops.rearrange(image, "b c h w -> b h w c")
state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
action = torch.tensor(pkl_data["actions"][id_from:id_to])
state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
# next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
next_done = torch.tensor(pkl_data["dones"][id_from:id_to])
# next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
# next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])
ep_dict = {}
@@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
@@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict["next.done"] = next_done
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from = id_to
ep_idx += 1
# process first episode only
if debug:
break
data_dict = concatenate_episodes(ep_dicts)
return data_dict, episode_data_index
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def to_hf_dataset(data_dict, video):
@@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video):
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
# sanity check
check_format(raw_dir)
if fps is None:
fps = 15
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,