Refactor push_dataset_to_hub
This commit is contained in:
@@ -4,20 +4,21 @@ useless dependencies when using datasets.
 """

 import io
+import shutil
 from pathlib import Path

 import tqdm


-def download_raw(root, dataset_id) -> Path:
+def download_raw(raw_dir, dataset_id) -> Path:
     if "pusht" in dataset_id:
-        return download_pusht(root=root, dataset_id=dataset_id)
+        return download_pusht(raw_dir)
     elif "xarm" in dataset_id:
-        return download_xarm(root=root, dataset_id=dataset_id)
+        return download_xarm(raw_dir)
     elif "aloha" in dataset_id:
-        return download_aloha(root=root, dataset_id=dataset_id)
+        return download_aloha(raw_dir, dataset_id)
     elif "umi" in dataset_id:
-        return download_umi(root=root, dataset_id=dataset_id)
+        return download_umi(raw_dir)
     else:
         raise ValueError(dataset_id)

@@ -50,42 +51,37 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
     return False


-def download_pusht(root: str, dataset_id: str = "pusht", fps: int = 10) -> Path:
+def download_pusht(raw_dir: str):
     pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
-    pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
-
-    root = Path(root)
-    raw_dir: Path = root / f"{dataset_id}_raw"
-    zarr_path: Path = (raw_dir / pusht_zarr).resolve()
+    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
     if not zarr_path.is_dir():
         raw_dir.mkdir(parents=True, exist_ok=True)
-        download_and_extract_zip(pusht_url, raw_dir)
-    return zarr_path
+        download_and_extract_zip(pusht_url, zarr_path)


-def download_xarm(root: str, dataset_id: str, fps: int = 15) -> Path:
-    root = Path(root)
-    raw_dir: Path = root / "xarm_datasets_raw"
-    if not raw_dir.exists():
-        import zipfile
-
-        import gdown
-
-        raw_dir.mkdir(parents=True, exist_ok=True)
-        # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
-        url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
-        zip_path = raw_dir / "data.zip"
-        gdown.download(url, str(zip_path), quiet=False)
-        print("Extracting...")
-        with zipfile.ZipFile(str(zip_path), "r") as zip_f:
-            for member in zip_f.namelist():
-                if member.startswith("data/xarm") and member.endswith(".pkl"):
-                    print(member)
-                    zip_f.extract(member=member)
-        zip_path.unlink()
-
-    dataset_path: Path = root / f"{dataset_id}"
-    return dataset_path
+def download_xarm(raw_dir: str) -> Path:
+    """Download all xarm datasets at once"""
+    import zipfile
+
+    import gdown
+
+    raw_dir.mkdir(parents=True, exist_ok=True)
+    # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
+    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
+    zip_path = raw_dir / "data.zip"
+    gdown.download(url, str(zip_path), quiet=False)
+    print("Extracting...")
+    with zipfile.ZipFile(str(zip_path), "r") as zip_f:
+        for path in zip_f.namelist():
+            if path.startswith("data/xarm") and path.endswith(".pkl"):
+                zip_f.extract(member=path)
+                # move to corresponding raw directory
+                member_dir = path.replace("/buffer.pkl", "")
+                member_raw_dir = path.replace("/buffer.pkl", "_raw")
+                shutil.move(path, member_raw_dir)
+                shutil.rmtree(member_dir)
+    zip_path.unlink()


 def download_aloha(root: str, dataset_id: str) -> Path:
@@ -148,13 +144,9 @@ def download_aloha(root: str, dataset_id: str) -> Path:
     return raw_dir


-def download_umi(root: str, dataset_id: str) -> Path:
+def download_umi(raw_dir: Path) -> Path:
     url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
-    cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr")
+    zarr_path = raw_dir / "cup_in_the_wild.zarr"

-    root = Path(root)
-    raw_dir: Path = root / f"{dataset_id}_raw"
-    zarr_path: Path = (raw_dir / cup_in_the_wild_zarr).resolve()
     if not zarr_path.is_dir():
         raw_dir.mkdir(parents=True, exist_ok=True)
         download_and_extract_zip(url_cup_in_the_wild, zarr_path)
@@ -162,7 +154,7 @@ def download_umi(root: str, dataset_id: str) -> Path:


 if __name__ == "__main__":
-    root = "data"
+    data_dir = Path("data")
     dataset_ids = [
         "pusht",
         "xarm_lift_medium",
@@ -176,4 +168,5 @@ if __name__ == "__main__":
         "umi_cup_in_the_wild",
     ]
     for dataset_id in dataset_ids:
-        download_raw(root=root, dataset_id=dataset_id)
+        raw_dir = data_dir / f"{dataset_id}_raw"
+        download_raw(raw_dir, dataset_id)
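For reference, a minimal usage sketch of the refactored entry point. The `data` directory and the `<dataset_id>_raw` suffix follow the `__main__` block above; treat them as conventions of this script rather than a stable API.

    from pathlib import Path

    from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

    # each dataset now lands in its own `<dataset_id>_raw` folder
    raw_dir = Path("data") / "pusht_raw"
    download_raw(raw_dir, "pusht")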
171 lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py Normal file
@@ -0,0 +1,171 @@
"""
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
"""

import re
import shutil
from pathlib import Path

import h5py
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import encode_video_frames


def is_valid_raw_format(raw_dir) -> bool:
    cameras = ["top"]

    hdf5_files: list[Path] = list(raw_dir.glob("episode_*.hdf5"))
    if len(hdf5_files) == 0:
        return False
    try:
        hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1)))
    except AttributeError:
        # All file names must contain a numerical identifier matching 'episode_(\d+).hdf5'
        return False

    # Check if the sequence is consecutive, e.g. episode_0, episode_1, episode_2, etc.
    # If not, return False
    previous_number = None
    for file in hdf5_files:
        current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1))
        if previous_number is not None and current_number - previous_number != 1:
            return False
        previous_number = current_number

    for file in hdf5_files:
        try:
            with h5py.File(file, "r") as file:
                # Check for the expected datasets within the HDF5 file
                required_datasets = ["/action", "/observations/qpos"]
                # Add camera-specific image datasets to the required datasets
                camera_datasets = [f"/observations/images/{cam}" for cam in cameras]
                required_datasets.extend(camera_datasets)

                if not all(dataset in file for dataset in required_datasets):
                    return False
        except OSError:
            return False
    return True


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    hdf5_files = list(raw_dir.glob("*.hdf5"))
    hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1)))
    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0

    for ep_path in tqdm.tqdm(hdf5_files):
        with h5py.File(ep_path, "r") as ep:
            ep_idx = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
            num_frames = ep["/action"].shape[0]

            # last step of demonstration is considered done
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True

            state = torch.from_numpy(ep["/observations/qpos"][:])
            action = torch.from_numpy(ep["/action"][:])

            ep_dict = {}

            cameras = list(ep["/observations/images"].keys())
            for cam in cameras:
                img_key = f"observation.images.{cam}"
                imgs_array = ep[f"/observations/images/{cam}"][:]  # b h w c
                if video:
                    # save png images in temporary directory
                    tmp_imgs_dir = out_dir / "tmp_images"
                    save_images_concurrently(imgs_array, tmp_imgs_dir)

                    # encode images to a mp4 video
                    video_path = out_dir / "videos" / f"{img_key}_episode_{ep_idx:06d}.mp4"
                    encode_video_frames(tmp_imgs_dir, video_path, fps)

                    # clean temporary images directory
                    shutil.rmtree(tmp_imgs_dir)

                    # store the episode idx
                    ep_dict[img_key] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
                else:
                    ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]

            ep_dict["observation.state"] = state
            ep_dict["action"] = action
            ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
            ep_dict["next.done"] = done
            # TODO(rcadene): compute reward and success
            # ep_dict["next.reward"] = reward
            # ep_dict["next.success"] = success

            assert isinstance(ep_idx, int)
            ep_dicts.append(ep_dict)

            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(id_from + num_frames)

        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index


def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    image_keys = [key for key in data_dict if "observation.images." in key]
    for image_key in image_keys:
        if video:
            features[image_key] = Value(dtype="int64", id="video")
        else:
            features[image_key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)
    # TODO(rcadene): add reward and success
    # features["next.reward"] = Value(dtype="float32", id=None)
    # features["next.success"] = Value(dtype="bool", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    assert is_valid_raw_format(raw_dir), f"{raw_dir} does not match the expected format."

    if fps is None:
        fps = 50

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
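A hedged usage sketch of this module on its own; the paths are illustrative, and in practice the push_dataset_to_hub.py script shown later in this commit wires these values from CLI arguments.

    from pathlib import Path

    from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format

    raw_dir = Path("data/aloha_sim_insertion_scripted_raw")
    out_dir = Path("data/lerobot/aloha_sim_insertion_scripted")
    # fps defaults to 50 for aloha; video=True encodes frames to mp4 and stores the episode index per frame
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps=None, video=True, debug=False)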
@@ -1,199 +0,0 @@
import re
from pathlib import Path

import h5py
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)


class AlohaProcessor:
    """
    Process HDF5 files formatted like in: https://github.com/tonyzhaozh/act

    Attributes:
        folder_path (Path): Path to the directory containing HDF5 files.
        cameras (list[str]): List of camera identifiers to check in the files.
        fps (int): Frames per second used in timestamp calculations.

    Methods:
        is_valid() -> bool:
            Validates if each HDF5 file within the folder contains all required datasets.
        preprocess() -> dict:
            Processes the files and returns structured data suitable for further analysis.
        to_hf_dataset(data_dict: dict) -> Dataset:
            Converts processed data into a Hugging Face Dataset object.
    """

    def __init__(self, folder_path: Path, cameras: list[str] | None = None, fps: int | None = None):
        """
        Initializes the AlohaProcessor with a specified directory path containing HDF5 files,
        an optional list of cameras, and a frame rate.

        Args:
            folder_path (Path): The directory path where HDF5 files are stored.
            cameras (list[str] | None): Optional list of cameras to validate within the files. Defaults to ['top'] if None.
            fps (int): Frame rate for the datasets, used in time calculations. Default is 50.

        Examples:
            >>> processor = AlohaProcessor(Path("path_to_hdf5_directory"), ["camera1", "camera2"])
            >>> processor.is_valid()
            True
        """
        self.folder_path = folder_path
        if cameras is None:
            cameras = ["top"]
        self.cameras = cameras
        if fps is None:
            fps = 50
        self._fps = fps

    @property
    def fps(self) -> int:
        return self._fps

    def is_valid(self) -> bool:
        """
        Validates the HDF5 files in the specified folder to ensure they contain the required datasets
        for actions, positions, and images for each specified camera.

        Returns:
            bool: True if all files are valid HDF5 files with all required datasets, False otherwise.
        """
        hdf5_files: list[Path] = list(self.folder_path.glob("episode_*.hdf5"))
        if len(hdf5_files) == 0:
            return False
        try:
            hdf5_files = sorted(
                hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1))
            )
        except AttributeError:
            # All file names must contain a numerical identifier matching 'episode_(\d+).hdf5'
            return False

        # Check if the sequence is consecutive, e.g. episode_0, episode_1, episode_2, etc.
        # If not, return False
        previous_number = None
        for file in hdf5_files:
            current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1))
            if previous_number is not None and current_number - previous_number != 1:
                return False
            previous_number = current_number

        for file in hdf5_files:
            try:
                with h5py.File(file, "r") as file:
                    # Check for the expected datasets within the HDF5 file
                    required_datasets = ["/action", "/observations/qpos"]
                    # Add camera-specific image datasets to the required datasets
                    camera_datasets = [f"/observations/images/{cam}" for cam in self.cameras]
                    required_datasets.extend(camera_datasets)

                    if not all(dataset in file for dataset in required_datasets):
                        return False
            except OSError:
                return False
        return True

    def preprocess(self):
        """
        Collects episode data from the HDF5 file and returns it as an AlohaStep named tuple.

        Returns:
            AlohaStep: Named tuple containing episode data.

        Raises:
            ValueError: If the file is not valid.
        """
        if not self.is_valid():
            raise ValueError("The HDF5 file is invalid or does not contain the required datasets.")

        hdf5_files = list(self.folder_path.glob("*.hdf5"))
        hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1)))
        ep_dicts = []
        episode_data_index = {"from": [], "to": []}

        id_from = 0

        for ep_path in tqdm.tqdm(hdf5_files):
            with h5py.File(ep_path, "r") as ep:
                ep_id = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
                num_frames = ep["/action"].shape[0]

                # last step of demonstration is considered done
                done = torch.zeros(num_frames, dtype=torch.bool)
                done[-1] = True

                state = torch.from_numpy(ep["/observations/qpos"][:])
                action = torch.from_numpy(ep["/action"][:])

                ep_dict = {}

                for cam in self.cameras:
                    image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])  # b h w c
                    ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]

                ep_dict.update(
                    {
                        "observation.state": state,
                        "action": action,
                        "episode_index": torch.tensor([ep_id] * num_frames),
                        "frame_index": torch.arange(0, num_frames, 1),
                        "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                        # TODO(rcadene): compute reward and success
                        # "next.reward": reward,
                        "next.done": done,
                        # "next.success": success,
                    }
                )

                assert isinstance(ep_id, int)
                ep_dicts.append(ep_dict)

                episode_data_index["from"].append(id_from)
                episode_data_index["to"].append(id_from + num_frames)

            id_from += num_frames

        data_dict = concatenate_episodes(ep_dicts)
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict) -> Dataset:
        """
        Converts a dictionary of data into a Hugging Face Dataset object.

        Args:
            data_dict (dict): A dictionary containing the data to be converted.

        Returns:
            Dataset: The converted Hugging Face Dataset object.
        """
        image_features = {f"observation.images.{cam}": Image() for cam in self.cameras}
        features = {
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            # "next.reward": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            # "next.success": Value(dtype="bool", id=None),
            "index": Value(dtype="int64", id=None),
        }
        update_features = {**image_features, **features}
        features = Features(update_features)
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        hf_dataset.set_transform(hf_transform_to_torch)

        return hf_dataset

    def cleanup(self):
        pass
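The class-based processors are removed in favor of per-format modules. A hedged before/after sketch; the "before" call pattern is inferred from the class API above and is not shown verbatim in this commit.

    # before (class-based, inferred usage)
    processor = AlohaProcessor(folder_path=raw_dir)
    data_dict, episode_data_index = processor.preprocess()
    hf_dataset = processor.to_hf_dataset(data_dict)

    # after (per-format module added by this commit)
    from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format

    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps=50, video=True)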
215 lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py Normal file
@@ -0,0 +1,215 @@
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""

import shutil
from pathlib import Path

import numpy as np
import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import encode_video_frames


def is_valid_raw_format(raw_dir) -> bool:
    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    try:
        zarr_data = zarr.open(zarr_path, mode="r")
    except Exception:
        # TODO (azouitine): Handle the exception properly
        return False
    required_datasets = {
        "data/action",
        "data/img",
        "data/keypoint",
        "data/n_contacts",
        "data/state",
        "meta/episode_ends",
    }
    for dataset in required_datasets:
        if dataset not in zarr_data:
            return False
    nb_frames = zarr_data["data/img"].shape[0]

    required_datasets.remove("meta/episode_ends")

    return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    try:
        import pymunk
        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely

        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
        raise e
    # as defined in gym-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
    success_threshold = 0.95  # 95% coverage

    zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)

    episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
    num_episodes = zarr_data.meta["episode_ends"].shape[0]
    assert len(
        {zarr_data[key].shape[0] for key in zarr_data.keys()}  # noqa: SIM118
    ), "Some data types don't have the same number of total frames."

    # TODO(rcadene): verify that goal pose is expected to be fixed
    goal_pos_angle = np.array([256, 256, np.pi / 4])  # x, y, theta (in radians)
    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)

    imgs = torch.from_numpy(zarr_data["img"])  # b h w c
    states = torch.from_numpy(zarr_data["state"])
    actions = torch.from_numpy(zarr_data["action"])

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    for ep_idx in tqdm.tqdm(range(num_episodes)):
        id_to = zarr_data.meta["episode_ends"][ep_idx]
        num_frames = id_to - id_from

        # sanity check
        assert (episode_ids[id_from:id_to] == ep_idx).all()

        # get image
        image = imgs[id_from:id_to]
        assert image.min() >= 0.0
        assert image.max() <= 255.0
        image = image.type(torch.uint8)

        # get state
        state = states[id_from:id_to]
        agent_pos = state[:, :2]
        block_pos = state[:, 2:4]
        block_angle = state[:, 4]

        # get reward, success, done
        reward = torch.zeros(num_frames)
        success = torch.zeros(num_frames, dtype=torch.bool)
        done = torch.zeros(num_frames, dtype=torch.bool)
        for i in range(num_frames):
            space = pymunk.Space()
            space.gravity = 0, 0
            space.damping = 0

            # Add walls.
            walls = [
                PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
                PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
                PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
                PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
            ]
            space.add(*walls)

            block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
            goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
            block_geom = pymunk_to_shapely(block_body, block_body.shapes)
            intersection_area = goal_geom.intersection(block_geom).area
            goal_area = goal_geom.area
            coverage = intersection_area / goal_area
            reward[i] = np.clip(coverage / success_threshold, 0, 1)
            success[i] = coverage > success_threshold

        # last step of demonstration is considered done
        done[-1] = True

        ep_dict = {}

        imgs_array = [x.numpy() for x in image]
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to a mp4 video
            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
            encode_video_frames(tmp_imgs_dir, video_path, fps)

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the episode index
            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = agent_pos
        ep_dict["action"] = actions[id_from:id_to]
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        # ep_dict["next.observation.image"] = image[1:]
        # ep_dict["next.observation.state"] = agent_pos[1:]
        # TODO(rcadene): verify that reward and done are aligned with image and agent_pos
        ep_dict["next.reward"] = torch.cat([reward[1:], reward[[-1]]])
        ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
        ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)

        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = Value(dtype="int64", id="video")
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["next.success"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    assert is_valid_raw_format(raw_dir), f"{raw_dir} does not match the expected format."

    if fps is None:
        fps = 10

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
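One detail worth calling out: `next.reward`, `next.done` and `next.success` are the per-frame signals shifted left by one step, with the final value repeated so the tensors keep their length. A small illustration with made-up values:

    import torch

    reward = torch.tensor([0.1, 0.2, 0.9])
    next_reward = torch.cat([reward[1:], reward[[-1]]])
    # next_reward == tensor([0.2, 0.9, 0.9])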
203 lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py Normal file
@@ -0,0 +1,203 @@
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""

import shutil
from pathlib import Path

import numpy as np
import torch
import tqdm
import zarr
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import encode_video_frames


def is_valid_raw_format(raw_dir) -> bool:
    zarr_path = raw_dir / "cup_in_the_wild.zarr"

    try:
        zarr_data = zarr.open(zarr_path, mode="r")
    except Exception:
        # TODO (azouitine): Handle the exception properly
        return False
    required_datasets = {
        "data/robot0_demo_end_pose",
        "data/robot0_demo_start_pose",
        "data/robot0_eef_pos",
        "data/robot0_eef_rot_axis_angle",
        "data/robot0_gripper_width",
        "meta/episode_ends",
        "data/camera0_rgb",
    }
    for dataset in required_datasets:
        if dataset not in zarr_data:
            return False

    register_codecs()

    nb_frames = zarr_data["data/camera0_rgb"].shape[0]

    required_datasets.remove("meta/episode_ends")

    return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    zarr_path = raw_dir / "cup_in_the_wild.zarr"
    zarr_data = zarr.open(zarr_path, mode="r")

    # We process the image data separately because it is too large to fit in memory
    end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
    start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
    eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
    eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
    gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])

    states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
    states = torch.cat([states_pos, gripper_width], dim=1)

    episode_ends = zarr_data["meta/episode_ends"][:]
    num_episodes = episode_ends.shape[0]

    episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))

    # We convert it to a torch tensor later because the jit function does not support torch tensors
    episode_ends = torch.from_numpy(episode_ends)

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    for ep_idx in tqdm.tqdm(range(num_episodes)):
        id_to = episode_ends[ep_idx]
        num_frames = id_to - id_from

        # sanity check
        assert (episode_ids[id_from:id_to] == ep_idx).all()

        # TODO(rcadene): save temporary images of the episode?

        state = states[id_from:id_to]

        ep_dict = {}

        # load 57MB of images in RAM (400x224x224x3 uint8)
        imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to a mp4 video
            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
            encode_video_frames(tmp_imgs_dir, video_path, fps)

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the episode index
            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = state
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
        ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
        ep_dict["end_pose"] = end_pose[id_from:id_to]
        ep_dict["start_pos"] = start_pos[id_from:id_to]
        ep_dict["gripper_width"] = gripper_width[id_from:id_to]
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)
        id_from += num_frames

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = id_from
    data_dict["index"] = torch.arange(0, total_frames, 1)

    return data_dict, episode_data_index


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = Value(dtype="int64", id="video")
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["index"] = Value(dtype="int64", id=None)
    features["episode_data_index_from"] = Value(dtype="int64", id=None)
    features["episode_data_index_to"] = Value(dtype="int64", id=None)
    # `start_pos` and `end_pos` respectively represent the positions of the end-effector
    # at the beginning and the end of the episode.
    # `gripper_width` indicates the distance between the grippers, and this value is included
    # in the state vector, which comprises the concatenation of the end-effector position
    # and gripper width.
    features["end_pose"] = Sequence(
        length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["start_pos"] = Sequence(
        length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["gripper_width"] = Sequence(
        length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
    )

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    assert is_valid_raw_format(raw_dir), f"{raw_dir} does not match the expected format."

    if fps is None:
        # For umi cup in the wild: https://arxiv.org/pdf/2402.10329#table.caption.16
        fps = 10

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info


def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
    # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
    from numba import jit

    @jit(nopython=True)
    def _get_episode_idxs(episode_ends):
        result = np.zeros((episode_ends[-1],), dtype=np.int64)
        start_idx = 0
        for episode_number, end_idx in enumerate(episode_ends):
            result[start_idx:end_idx] = episode_number
            start_idx = end_idx
        return result

    return _get_episode_idxs(episode_ends)
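`get_episode_idxs` expands the cumulative `episode_ends` array into a per-frame episode index. A quick worked example:

    import numpy as np

    # episodes end at frames 3 and 5 (exclusive ends), i.e. lengths 3 and 2
    episode_ends = np.array([3, 5])
    # get_episode_idxs(episode_ends) -> array([0, 0, 0, 1, 1])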
@@ -1,3 +1,8 @@
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+import numpy
+import PIL
 import torch


@@ -18,3 +23,16 @@ def concatenate_episodes(ep_dicts):
     total_frames = data_dict["frame_index"].shape[0]
     data_dict["index"] = torch.arange(0, total_frames, 1)
     return data_dict
+
+
+def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
+    out_dir = Path(out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    def save_image(img_array, i, out_dir):
+        img = PIL.Image.fromarray(img_array)
+        img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
+
+    num_images = len(imgs_array)
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
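A hedged usage sketch of the new helper; the array shape follows how the format modules call it, i.e. a batch of H×W×C uint8 frames.

    import numpy as np

    from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently

    imgs_array = np.zeros((10, 224, 224, 3), dtype=np.uint8)  # 10 dummy frames
    save_images_concurrently(imgs_array, "tmp_images", max_workers=4)
    # writes tmp_images/frame_000000.png ... tmp_images/frame_000009.png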
173 lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py Normal file
@@ -0,0 +1,173 @@
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""

import pickle
import shutil
from pathlib import Path

import einops
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import encode_video_frames


def is_valid_raw_format(raw_dir):
    keys = {"actions", "rewards", "dones"}
    nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}

    xarm_files = list(raw_dir.glob("*.pkl"))
    if len(xarm_files) != 1:
        return False

    try:
        with open(xarm_files[0], "rb") as f:
            dataset_dict = pickle.load(f)
    except Exception:
        return False

    if not isinstance(dataset_dict, dict):
        return False

    if not all(k in dataset_dict for k in keys):
        return False

    # Check for consistent lengths in nested keys
    try:
        expected_len = len(dataset_dict["actions"])
        if any(len(dataset_dict[key]) != expected_len for key in keys if key in dataset_dict):
            return False

        for key, subkeys in nested_keys.items():
            nested_dict = dataset_dict.get(key, {})
            if any(len(nested_dict[subkey]) != expected_len for subkey in subkeys if subkey in nested_dict):
                return False
    except KeyError:  # If any expected key or subkey is missing
        return False

    return True  # All checks passed


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    xarm_files = list(raw_dir.glob("*.pkl"))

    with open(xarm_files[0], "rb") as f:
        dataset_dict = pickle.load(f)
    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    id_to = 0
    ep_idx = 0
    total_frames = dataset_dict["actions"].shape[0]
    for i in tqdm.tqdm(range(total_frames)):
        id_to += 1

        if not dataset_dict["dones"][i]:
            continue

        num_frames = id_to - id_from

        image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
        image = einops.rearrange(image, "b c h w -> b h w c")
        state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
        action = torch.tensor(dataset_dict["actions"][id_from:id_to])
        # TODO(rcadene): we have a missing last frame which is the observation when the env is done
        # it is critical to have this frame for tdmpc to predict a "done observation/state"
        # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
        # next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
        next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
        next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])

        ep_dict = {}

        imgs_array = [x.numpy() for x in image]
        if video:
            # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            save_images_concurrently(imgs_array, tmp_imgs_dir)

            # encode images to a mp4 video
            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
            encode_video_frames(tmp_imgs_dir, video_path, fps)

            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

            # store the episode index
            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]

        ep_dict["observation.state"] = state
        ep_dict["action"] = action
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
        # ep_dict["next.observation.image"] = next_image
        # ep_dict["next.observation.state"] = next_state
        ep_dict["next.reward"] = next_reward
        ep_dict["next.done"] = next_done
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)

        id_from = id_to
        ep_idx += 1

        # process first episode only
        if debug:
            break

    data_dict = concatenate_episodes(ep_dicts)
    return data_dict, episode_data_index


def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = Value(dtype="int64", id="video")
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)
    # TODO(rcadene): add success
    # features["next.success"] = Value(dtype="bool", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
    assert is_valid_raw_format(raw_dir), f"{raw_dir} does not match the expected format."

    if fps is None:
        fps = 15

    data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
    hf_dataset = to_hf_dataset(data_dict, video)

    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
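The xarm pickle stores one flat buffer of frames; episodes are recovered by scanning `dones`. A tiny illustration of the `id_from`/`id_to` bookkeeping above, with made-up values:

    dones = [False, False, True, False, True]
    # frame indices:  0      1      2     3      4
    # episode 0 -> frames [0, 3)   (id_from=0, id_to=3)
    # episode 1 -> frames [3, 5)   (id_from=3, id_to=5)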
@@ -1,295 +1,215 @@
|
|||||||
|
"""
|
||||||
|
Use this script to convert your dataset into our dataset format and upload it to the Hugging Face hub,
|
||||||
|
or store it locally. Our dataset format is lightweight, fast to load from, and does not require any
|
||||||
|
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
|
--data-dir data \
|
||||||
|
--dataset-id pusht \
|
||||||
|
--raw-format pusht_zarr \
|
||||||
|
--community-id lerobot \
|
||||||
|
--revision v1.2 \
|
||||||
|
--dry-run 1 \
|
||||||
|
--save-to-disk 1 \
|
||||||
|
--save-tests-to-disk 0 \
|
||||||
|
--debug 1
|
||||||
|
|
||||||
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
|
--data-dir data \
|
||||||
|
--dataset-id xarm_lift_medium \
|
||||||
|
--raw-format xarm_pkl \
|
||||||
|
--community-id lerobot \
|
||||||
|
--revision v1.2 \
|
||||||
|
--dry-run 1 \
|
||||||
|
--save-to-disk 1 \
|
||||||
|
--save-tests-to-disk 0 \
|
||||||
|
--debug 1
|
||||||
|
|
||||||
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
|
--data-dir data \
|
||||||
|
--dataset-id aloha_sim_insertion_scripted \
|
||||||
|
--raw-format aloha_hdf5 \
|
||||||
|
--community-id lerobot \
|
||||||
|
--revision v1.2 \
|
||||||
|
--dry-run 1 \
|
||||||
|
--save-to-disk 1 \
|
||||||
|
--save-tests-to-disk 0 \
|
||||||
|
--debug 1
|
||||||
|
|
||||||
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
|
--data-dir data \
|
||||||
|
--dataset-id umi_cup_in_the_wild \
|
||||||
|
--raw-format umi_zarr \
|
||||||
|
--community-id lerobot \
|
||||||
|
--revision v1.2 \
|
||||||
|
--dry-run 1 \
|
||||||
|
--save-to-disk 1 \
|
||||||
|
--save-tests-to-disk 0 \
|
||||||
|
--debug 1
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Protocol
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_processor import (
|
|
||||||
AlohaProcessor,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_processor import PushTProcessor
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.umi_processor import UmiProcessor
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_processor import XarmProcessor
|
|
||||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict
|
from lerobot.common.datasets.utils import compute_stats, flatten_dict
|
||||||
|
|
||||||
|
|
||||||
def push_lerobot_dataset_to_hub(
|
def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||||
hf_dataset: Dataset,
|
if raw_format == "pusht_zarr":
|
||||||
episode_data_index: dict[str, list[int]],
|
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
||||||
info: dict[str, Any],
|
elif raw_format == "umi_zarr":
|
||||||
stats: dict[str, dict[str, torch.Tensor]],
|
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||||
root: Path,
|
elif raw_format == "aloha_hdf5":
|
||||||
revision: str,
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||||
dataset_id: str,
|
elif raw_format == "xarm_pkl":
|
||||||
community_id: str = "lerobot",
|
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||||
dry_run: bool = False,
|
else:
|
||||||
) -> None:
|
raise ValueError(raw_format)
|
||||||
"""
|
|
||||||
Pushes a dataset to the Hugging Face Hub.
|
|
||||||
|
|
||||||
Args:
|
return from_raw_to_lerobot_format
|
||||||
hf_dataset (Dataset): The dataset to be pushed.
|
|
||||||
episode_data_index (dict[str, list[int]]): The index of episode data.
|
|
||||||
info (dict[str, Any]): Information about the dataset, eg. fps.
|
|
||||||
stats (dict[str, dict[str, torch.Tensor]]): Statistics of the dataset.
|
|
||||||
root (Path): The root directory of the dataset.
|
|
||||||
revision (str): The revision of the dataset.
|
|
||||||
dataset_id (str): The ID of the dataset.
|
|
||||||
community_id (str, optional): The ID of the community or the user where the
|
|
||||||
dataset will be stored. Defaults to "lerobot".
|
|
||||||
dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
|
|
||||||
"""
|
|
||||||
if not dry_run:
|
|
||||||
# push to main to indicate latest version
|
|
||||||
hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True)
|
|
||||||
|
|
||||||
# push to version branch
|
|
||||||
hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True, revision=revision)
|
|
||||||
|
|
||||||
# create and store meta_data
|
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
|
||||||
meta_data_dir = root / community_id / dataset_id / "meta_data"
|
|
||||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# info
|
# save info
|
||||||
info_path = meta_data_dir / "info.json"
|
info_path = meta_data_dir / "info.json"
|
||||||
|
|
||||||
with open(str(info_path), "w") as f:
|
with open(str(info_path), "w") as f:
|
||||||
json.dump(info, f, indent=4)
|
json.dump(info, f, indent=4)
|
||||||
# stats
|
|
||||||
|
# save stats
|
||||||
stats_path = meta_data_dir / "stats.safetensors"
|
stats_path = meta_data_dir / "stats.safetensors"
|
||||||
save_file(flatten_dict(stats), stats_path)
|
save_file(flatten_dict(stats), stats_path)
|
||||||
|
|
||||||
# episode_data_index
|
# save episode_data_index
|
||||||
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
||||||
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
||||||
save_file(episode_data_index, ep_data_idx_path)
|
save_file(episode_data_index, ep_data_idx_path)
|
||||||
|
|
||||||
if not dry_run:
|
|
||||||
api = HfApi()
|
|
||||||
|
|
||||||
|
def push_meta_data_to_hub(meta_data_dir, repo_id, revision):
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
def upload(filename, revision):
|
||||||
api.upload_file(
|
api.upload_file(
|
||||||
path_or_fileobj=info_path,
|
path_or_fileobj=meta_data_dir / filename,
|
||||||
path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
path_in_repo=f"meta_data/{filename}",
|
||||||
repo_id=f"{community_id}/{dataset_id}",
|
repo_id=repo_id,
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=info_path,
|
|
||||||
path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
|
||||||
repo_id=f"{community_id}/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
repo_type="dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
# stats
|
upload("info.json", "main")
|
||||||
api.upload_file(
|
upload("info.json", revision)
|
||||||
path_or_fileobj=stats_path,
|
upload("stats.safetensors", "main")
|
||||||
path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
upload("stats.safetensors", revision)
|
||||||
repo_id=f"{community_id}/{dataset_id}",
|
upload("episode_data_index.safetensors", "main")
|
||||||
repo_type="dataset",
|
upload("episode_data_index.safetensors", revision)
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=stats_path,
|
|
||||||
path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
|
||||||
repo_id=f"{community_id}/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=ep_data_idx_path,
|
|
||||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
|
||||||
repo_id=f"{community_id}/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=ep_data_idx_path,
|
|
||||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
|
||||||
repo_id=f"{community_id}/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# copy in tests folder, the first episode and the meta_data directory
|
|
||||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
|
||||||
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
|
||||||
f"tests/data/{community_id}/{dataset_id}/train"
|
|
||||||
)
|
|
||||||
if Path(f"tests/data/{community_id}/{dataset_id}/meta_data").exists():
|
|
||||||
shutil.rmtree(f"tests/data/{community_id}/{dataset_id}/meta_data")
|
|
||||||
shutil.copytree(meta_data_dir, f"tests/data/{community_id}/{dataset_id}/meta_data")
|
|
||||||
|
|
||||||
|
|
||||||
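To sanity-check what save_meta_data writes (and what push_meta_data_to_hub uploads), the three files can be read back locally. A minimal sketch; the local path is illustrative and the stats keys are whatever flatten_dict produced:

import json
from pathlib import Path

from safetensors.torch import load_file

meta_data_dir = Path("data/lerobot/pusht/meta_data")  # illustrative local path

# info.json holds small scalar metadata such as the fps.
with open(meta_data_dir / "info.json") as f:
    info = json.load(f)

# stats.safetensors holds per-key statistics, flattened with flatten_dict before saving.
flat_stats = load_file(meta_data_dir / "stats.safetensors")

# episode_data_index.safetensors holds the "from"/"to" frame boundaries per episode.
episode_data_index = load_file(meta_data_dir / "episode_data_index.safetensors")
num_frames_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]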
def push_dataset_to_hub(
|
def push_dataset_to_hub(
|
||||||
|
data_dir: Path,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
root: Path,
|
raw_format: str | None,
|
||||||
fps: int | None,
|
community_id: str,
|
||||||
dataset_folder: Path | None = None,
|
revision: str,
|
||||||
dry_run: bool = False,
|
dry_run: bool,
|
||||||
revision: str = "v1.1",
|
save_to_disk: bool,
|
||||||
community_id: str = "lerobot",
|
tests_data_dir: Path,
|
||||||
no_preprocess: bool = False,
|
save_tests_to_disk: bool,
|
||||||
path_save_to_disk: str | None = None,
|
fps: int,
|
||||||
**kwargs,
|
video: bool,
|
||||||
) -> None:
|
debug: bool,
|
||||||
"""
|
):
|
||||||
Download a raw dataset if needed (or reuse a local raw dataset), detect its raw format (e.g. aloha, pusht, umi), convert it to a common data format, and push the result to the Hugging Face Hub.
|
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||||
|
|
||||||
Args:
|
out_dir = data_dir / community_id / dataset_id
|
||||||
dataset_id (str): The ID of the dataset.
|
meta_data_dir = out_dir / "meta_data"
|
||||||
root (Path): The root directory where the dataset will be downloaded.
|
videos_dir = out_dir / "videos"
|
||||||
fps (int | None): The desired frames per second for the dataset.
|
|
||||||
dataset_folder (Path | None, optional): The path to the dataset folder. If not provided, the dataset will be downloaded using the dataset ID. Defaults to None.
|
|
||||||
dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
|
|
||||||
revision (str, optional): Version of the `push_dataset_to_hub.py` codebase used to preprocess the dataset. Defaults to "v1.1".
|
|
||||||
community_id (str, optional): The ID of the community. Defaults to "lerobot".
|
|
||||||
no_preprocess (bool, optional): If True, does not preprocess the dataset. Defaults to False.
|
|
||||||
path_save_to_disk (str | None, optional): The path to save the dataset to disk. Works when `dry_run` is True, which allows saving to disk only, without uploading. By default, the dataset is not saved on disk.
|
|
||||||
**kwargs: Additional keyword arguments for the preprocessor init method.
|
|
||||||
|
|
||||||
|
tests_out_dir = tests_data_dir / community_id / dataset_id
|
||||||
|
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||||
|
|
||||||
"""
|
if out_dir.exists():
|
||||||
if dataset_folder is None:
|
shutil.rmtree(out_dir)
|
||||||
dataset_folder = download_raw(root=root, dataset_id=dataset_id)
|
|
||||||
|
|
||||||
if not no_preprocess:
|
if tests_out_dir.exists():
|
||||||
processor = guess_dataset_type(dataset_folder=dataset_folder, fps=fps, **kwargs)
|
shutil.rmtree(tests_out_dir)
|
||||||
data_dict, episode_data_index = processor.preprocess()
|
|
||||||
hf_dataset = processor.to_hf_dataset(data_dict)
|
|
||||||
|
|
||||||
info = {
|
if not raw_dir.exists():
|
||||||
"fps": processor.fps,
|
download_raw(raw_dir, dataset_id)
|
||||||
}
|
|
||||||
stats: dict[str, dict[str, torch.Tensor]] = compute_stats(hf_dataset)
|
|
||||||
|
|
||||||
push_lerobot_dataset_to_hub(
|
if raw_format is None:
|
||||||
hf_dataset=hf_dataset,
|
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
|
||||||
episode_data_index=episode_data_index,
|
raise NotImplementedError()
|
||||||
info=info,
|
# raw_format = auto_find_raw_format(raw_dir)
|
||||||
stats=stats,
|
|
||||||
root=root,
|
|
||||||
revision=revision,
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
community_id=community_id,
|
|
||||||
dry_run=dry_run,
|
|
||||||
)
|
|
||||||
if path_save_to_disk:
|
|
||||||
hf_dataset.with_format("torch").save_to_disk(dataset_path=str(path_save_to_disk))
|
|
||||||
|
|
||||||
processor.cleanup()
|
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||||
|
|
||||||
|
# convert dataset from original raw format to LeRobot format
|
||||||
|
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||||
|
|
||||||
class DatasetProcessor(Protocol):
|
stats = compute_stats(hf_dataset)
|
||||||
"""A class for processing datasets.
|
|
||||||
|
|
||||||
This class provides methods for validating, preprocessing, and converting datasets.
|
if save_to_disk:
|
||||||
|
hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
|
||||||
|
hf_dataset.save_to_disk(str(out_dir))
|
||||||
|
|
||||||
Args:
|
if not dry_run or save_to_disk:
|
||||||
folder_path (str): The path to the folder containing the dataset.
|
# mandatory for upload
|
||||||
fps (int | None): The frames per second of the dataset. If None, the default value is used.
|
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||||
*args: Additional positional arguments.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, folder_path: str, fps: int | None, *args, **kwargs) -> None: ...
|
if not dry_run:
|
||||||
|
repo_id = f"{community_id}/{dataset_id}"
|
||||||
|
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
||||||
|
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
||||||
|
push_meta_data_to_hub(meta_data_dir, repo_id, revision)
|
||||||
|
if video:
|
||||||
|
push_meta_data_to_hub(videos_dir, repo_id, revision)
|
||||||
|
|
||||||
def is_valid(self) -> bool:
|
if save_tests_to_disk:
|
||||||
"""Check if the dataset is valid.
|
# get the first episode
|
||||||
|
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||||
|
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||||
|
|
||||||
Returns:
|
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||||
bool: True if the dataset is valid, False otherwise.
|
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def preprocess(self) -> tuple[dict, dict]:
|
# copy meta data to tests directory
|
||||||
"""Preprocess the dataset.
|
if Path(tests_meta_data_dir).exists():
|
||||||
|
shutil.rmtree(tests_meta_data_dir)
|
||||||
Returns:
|
shutil.copytree(meta_data_dir, tests_meta_data_dir)
|
||||||
tuple[dict, dict]: A tuple containing two dictionaries representing the preprocessed data.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def to_hf_dataset(self, data_dict: dict) -> Dataset:
|
|
||||||
"""Convert the preprocessed data to a Hugging Face dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_dict (dict): The preprocessed data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dataset: The converted Hugging Face dataset.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def fps(self) -> int:
|
|
||||||
"""Get the frames per second of the dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The frames per second.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Clean up any resources used by the dataset processor."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def guess_dataset_type(dataset_folder: Path, **processor_kwargs) -> DatasetProcessor:
|
|
||||||
if (processor := AlohaProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
|
||||||
return processor
|
|
||||||
if (processor := XarmProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
|
||||||
return processor
|
|
||||||
if (processor := PushTProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
|
||||||
return processor
|
|
||||||
if (processor := UmiProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
|
||||||
return processor
|
|
||||||
# TODO: Propose a registration mechanism for new dataset types
|
|
||||||
raise ValueError(f"Could not guess dataset type for folder {dataset_folder}")
|
|
||||||
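The DatasetProcessor protocol removed above described a validate/preprocess/convert/cleanup contract; a minimal sketch of how the old pipeline used it, assuming AlohaProcessor is imported as in the old script and with an illustrative folder path:

# Any processor class (e.g. AlohaProcessor) was expected to satisfy DatasetProcessor.
processor = AlohaProcessor(folder_path="path/to/raw_dataset")
if processor.is_valid():
    data_dict, episode_data_index = processor.preprocess()
    hf_dataset = processor.to_hf_dataset(data_dict)
    processor.cleanup()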
|
|
||||||
|
|
||||||
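Before the CLI entry point below, a sketch of calling the refactored push_dataset_to_hub directly with its new signature; all values are illustrative and dry_run=True keeps everything local:

from pathlib import Path

push_dataset_to_hub(
    data_dir=Path("data"),
    dataset_id="pusht",
    raw_format="pusht_zarr",
    community_id="lerobot",
    revision="v1.2",
    dry_run=True,              # do not upload anything to the Hub
    save_to_disk=True,         # write the converted dataset under data/lerobot/pusht
    tests_data_dir=Path("tests/data"),
    save_tests_to_disk=True,   # also write a first-episode copy for the unit tests
    fps=10,
    video=False,               # video conversion is still disabled (see the TODO in the CLI)
    debug=False,
)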
def main():
|
def main():
|
||||||
"""
|
parser = argparse.ArgumentParser()
|
||||||
Main function to process command line arguments and push dataset to Hugging Face Hub.
|
|
||||||
|
|
||||||
Parses command line arguments to get dataset details and conditions under which the dataset
|
|
||||||
is processed and pushed. It manages dataset preparation and uploading based on the user-defined parameters.
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Push a dataset to the Hugging Face Hub with optional parameters for customization.",
|
|
||||||
epilog="""
|
|
||||||
Example usage:
|
|
||||||
python -m lerobot.scripts.push_dataset_to_hub --dataset-folder /path/to/dataset --dataset-id example_dataset --root /path/to/root --dry-run --revision v2.0 --community-id example_community --fps 30 --path-save-to-disk /path/to/save --no-preprocess
|
|
||||||
|
|
||||||
This processes and optionally pushes 'example_dataset' located in '/path/to/dataset' to Hugging Face Hub,
|
|
||||||
with various parameters to control the processing and uploading behavior.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset-folder",
|
"--data-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
required=True,
|
||||||
help="The filesystem path to the dataset folder. If not provided, the dataset must be identified and managed by other means.",
|
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset-id",
|
"--dataset-id",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Unique identifier for the dataset to be processed and uploaded.",
|
help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--root", type=Path, required=True, help="Root directory where the dataset operations are managed."
|
"--raw-format",
|
||||||
)
|
type=str,
|
||||||
parser.add_argument(
|
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
|
||||||
"--dry-run",
|
|
||||||
action="store_true",
|
|
||||||
help="Simulate the push process without uploading any data, for testing purposes.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--community-id",
|
"--community-id",
|
||||||
@@ -297,41 +217,57 @@ def main():
|
|||||||
default="lerobot",
|
default="lerobot",
|
||||||
help="Community or user ID under which the dataset will be hosted on the Hub.",
|
help="Community or user ID under which the dataset will be hosted on the Hub.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--fps",
|
|
||||||
type=int,
|
|
||||||
help="Target frame rate for video or image sequence datasets. Optional and applicable only if the dataset includes temporal media.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--revision",
|
"--revision",
|
||||||
type=str,
|
type=str,
|
||||||
default="v1.0",
|
default="v1.2",
|
||||||
help="Dataset version identifier to manage different iterations of the dataset.",
|
help="Codebase version used to generate the dataset.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-preprocess",
|
"--dry-run",
|
||||||
action="store_true",
|
type=int,
|
||||||
help="Does not preprocess the dataset, set this flag if you only want dowload the dataset raw.",
|
default=0,
|
||||||
|
help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--path-save-to-disk",
|
"--save-to-disk",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Save the dataset in the directory specified by `--data-dir`.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tests-data-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
help="Optional path where the processed dataset can be saved locally.",
|
default="tests/data",
|
||||||
|
help="Directory containing tests artifacts datasets.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-tests-to-disk",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fps",
|
||||||
|
type=int,
|
||||||
|
help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--video",
|
||||||
|
type=int,
|
||||||
|
# TODO(rcadene): enable when video PR merges
|
||||||
|
default=0,
|
||||||
|
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Debug mode process the first episode only.",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
push_dataset_to_hub(**vars(args))
|
||||||
push_dataset_to_hub(
|
|
||||||
dataset_folder=args.dataset_folder,
|
|
||||||
dataset_id=args.dataset_id,
|
|
||||||
root=args.root,
|
|
||||||
fps=args.fps,
|
|
||||||
dry_run=args.dry_run,
|
|
||||||
community_id=args.community_id,
|
|
||||||
revision=args.revision,
|
|
||||||
no_preprocess=args.no_preprocess,
|
|
||||||
path_save_to_disk=args.path_save_to_disk,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
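One note on the CLI design: the boolean-like options use type=int, so push_dataset_to_hub(**vars(args)) forwards 0/1 integers into parameters annotated as bool, relying on Python truthiness. A hedged sketch of an optional normalization helper (hypothetical, not part of this change):

def normalize_cli_flags(raw_args: dict) -> dict:
    """Convert the 0/1 integer CLI flags into real booleans (hypothetical helper)."""
    bool_flags = {"dry_run", "save_to_disk", "save_tests_to_disk", "video", "debug"}
    return {k: (bool(v) if k in bool_flags else v) for k, v in raw_args.items()}

# Possible usage inside main():
#   push_dataset_to_hub(**normalize_cli_flags(vars(parser.parse_args())))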
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ from safetensors.torch import load_file
|
|||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import (
|
||||||
|
LeRobotDataset,
|
||||||
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
compute_stats,
|
compute_stats,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
@@ -22,8 +24,7 @@ from lerobot.common.datasets.utils import (
|
|||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||||
from .utils import DEFAULT_CONFIG_PATH, DEVICE
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
|
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
|
||||||
|
|||||||