diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a466cff7..038b4458 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,6 +10,7 @@ on: - "examples/**" - ".github/**" - "poetry.lock" + - "Makefile" push: branches: - main @@ -19,6 +20,7 @@ on: - "examples/**" - ".github/**" - "poetry.lock" + - "Makefile" jobs: pytest: @@ -32,8 +34,8 @@ jobs: with: lfs: true # Ensure LFS files are pulled - - name: Install EGL - run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev + - name: Install apt dependencies + run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg - name: Install poetry run: | @@ -70,6 +72,9 @@ jobs: with: lfs: true # Ensure LFS files are pulled + - name: Install apt dependencies + run: sudo apt-get update && sudo apt-get install -y ffmpeg + - name: Install poetry run: | pipx install poetry && poetry config virtualenvs.in-project true @@ -104,7 +109,7 @@ jobs: with: lfs: true # Ensure LFS files are pulled - - name: Install EGL + - name: Install apt dependencies run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev - name: Install poetry diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 00000000..b406d43b --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,18 @@ +on: + push: + +name: Secret Leaks + +permissions: + contents: read + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main diff --git a/Makefile b/Makefile index 33f3edf2..9bac437d 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ PYTHON_PATH := $(shell which python) # If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python POETRY_CHECK := $(shell command -v poetry) ifneq ($(POETRY_CHECK),) - PYTHON_PATH := $(shell poetry run which python) + PYTHON_PATH := $(shell poetry run which python) endif export PATH := $(dir $(PYTHON_PATH)):$(PATH) @@ -46,6 +46,7 @@ test-act-ete-train: policy.n_action_steps=20 \ policy.chunk_size=20 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/act/ test-act-ete-eval: @@ -73,6 +74,7 @@ test-act-ete-train-amp: policy.chunk_size=20 \ training.batch_size=2 \ hydra.run.dir=tests/outputs/act_amp/ \ + training.image_transforms.enable=true \ use_amp=true test-act-ete-eval-amp: @@ -100,6 +102,7 @@ test-diffusion-ete-train: training.save_checkpoint=true \ training.save_freq=2 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/diffusion/ test-diffusion-ete-eval: @@ -127,6 +130,7 @@ test-tdmpc-ete-train: training.save_checkpoint=true \ training.save_freq=2 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/tdmpc/ test-tdmpc-ete-eval: @@ -159,5 +163,6 @@ test-act-pusht-tutorial: training.save_model=true \ training.save_freq=2 \ training.batch_size=2 \ + training.image_transforms.enable=true \ hydra.run.dir=tests/outputs/act_pusht/ rm lerobot/configs/policy/created_by_Makefile.yaml diff --git a/README.md b/README.md index 12ebe8d0..d76969bc 100644 --- a/README.md +++ b/README.md @@ -228,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` -Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with: +Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with: ```bash python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id aloha_static_pingpong_test \ ---raw-format aloha_hdf5 \ ---community-id lerobot +--raw-dir data/aloha_static_pingpong_test_raw \ +--out-dir data \ +--repo-id lerobot/aloha_static_pingpong_test \ +--raw-format aloha_hdf5 ``` See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions. diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index 70a5b505..db9840a7 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -46,7 +46,7 @@ defaults: - policy: diffusion ``` -This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overriden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_. +This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_. Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`. diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py new file mode 100644 index 00000000..bdcc6d7b --- /dev/null +++ b/examples/6_add_image_transforms.py @@ -0,0 +1,52 @@ +""" +This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data +augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and +transforms are applied to the observation images before they are returned in the dataset's __get_item__. +""" + +from pathlib import Path + +from torchvision.transforms import ToPILImage, v2 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +dataset_repo_id = "lerobot/aloha_static_tape" + +# Create a LeRobotDataset with no transformations +dataset = LeRobotDataset(dataset_repo_id) +# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)` + +# Get the index of the first observation in the first episode +first_idx = dataset.episode_data_index["from"][0].item() + +# Get the frame corresponding to the first camera +frame = dataset[first_idx][dataset.camera_keys[0]] + + +# Define the transformations +transforms = v2.Compose( + [ + v2.ColorJitter(brightness=(0.5, 1.5)), + v2.ColorJitter(contrast=(0.5, 1.5)), + v2.RandomAdjustSharpness(sharpness_factor=2, p=1), + ] +) + +# Create another LeRobotDataset with the defined transformations +transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms) + +# Get a frame from the transformed dataset +transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]] + +# Create a directory to store output images +output_dir = Path("outputs/image_transforms") +output_dir.mkdir(parents=True, exist_ok=True) + +# Save the original frame +to_pil = ToPILImage() +to_pil(frame).save(output_dir / "original_frame.png", quality=100) +print(f"Original frame saved to {output_dir / 'original_frame.png'}.") + +# Save the transformed frame +to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100) +print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.") diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4732f577..754bc91b 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -19,6 +19,7 @@ import torch from omegaconf import ListConfig, OmegaConf from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms def resolve_delta_timestamps(cfg): @@ -71,17 +72,37 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData resolve_delta_timestamps(cfg) - # TODO(rcadene): add data augmentations + image_transforms = None + if cfg.training.image_transforms.enable: + cfg_tf = cfg.training.image_transforms + image_transforms = get_image_transforms( + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, + ) if isinstance(cfg.dataset_repo_id, str): dataset = LeRobotDataset( cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps"), + image_transforms=image_transforms, ) else: dataset = MultiLeRobotDataset( - cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps") + cfg.dataset_repo_id, + split=split, + delta_timestamps=cfg.training.get("delta_timestamps"), + image_transforms=image_transforms, ) if cfg.get("override_dataset_stats"): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 58ae51b1..d680b987 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset): version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: Callable | None = None, + image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.version = version self.root = root self.split = split - self.transform = transform + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps # load data from hub or locally when root is provided # TODO(rcadene, aliberts): implement faster transfer @@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s, ) - if self.transform is not None: - item = self.transform(item) + if self.image_transforms is not None: + for cam in self.camera_keys: + item[cam] = self.image_transforms(item[cam]) return item @@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" ) @@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.version = version obj.root = root obj.split = split - obj.transform = transform + obj.image_transforms = transform obj.delta_timestamps = delta_timestamps obj.hf_dataset = hf_dataset obj.episode_data_index = episode_data_index @@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: Callable | None = None, + image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): root=root, split=split, delta_timestamps=delta_timestamps, - transform=transform, + image_transforms=image_transforms, ) for repo_id in repo_ids ] @@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): self.version = version self.root = root self.split = split - self.transform = transform + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.stats = aggregate_stats(self._datasets) @@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): for data_key in self.disabled_data_keys: if data_key in item: del item[data_key] + return item def __repr__(self): @@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" ) diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py index 7074bcba..7974ab8e 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -14,156 +14,119 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This file contains all obsolete download scripts. They are centralized here to not have to load -useless dependencies when using datasets. +This file contains download scripts for raw datasets. + +Example of usage: +``` +python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \ +--raw-dir data/cadene/pusht_raw \ +--repo-id cadene/pusht_raw +``` """ -import io +import argparse import logging -import shutil +import warnings from pathlib import Path -import tqdm from huggingface_hub import snapshot_download -def download_raw(raw_dir, dataset_id): - if "aloha" in dataset_id or "image" in dataset_id: - download_hub(raw_dir, dataset_id) - elif "pusht" in dataset_id: - download_pusht(raw_dir) - elif "xarm" in dataset_id: - download_xarm(raw_dir) - elif "umi" in dataset_id: - download_umi(raw_dir) - else: - raise ValueError(dataset_id) +def download_raw(raw_dir: Path, repo_id: str): + # Check repo_id is well formated + if len(repo_id.split("/")) != 2: + raise ValueError( + f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'." + ) + user_id, dataset_id = repo_id.split("/") - -def download_and_extract_zip(url: str, destination_folder: Path) -> bool: - import zipfile - - import requests - - print(f"downloading from {url}") - response = requests.get(url, stream=True) - if response.status_code == 200: - total_size = int(response.headers.get("content-length", 0)) - progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) - - zip_file = io.BytesIO() - for chunk in response.iter_content(chunk_size=1024): - if chunk: - zip_file.write(chunk) - progress_bar.update(len(chunk)) - - progress_bar.close() - - zip_file.seek(0) - - with zipfile.ZipFile(zip_file, "r") as zip_ref: - zip_ref.extractall(destination_folder) - - -def download_pusht(raw_dir: str): - pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" + if not dataset_id.endswith("_raw"): + warnings.warn( + f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.", + stacklevel=1, + ) raw_dir = Path(raw_dir) - raw_dir.mkdir(parents=True, exist_ok=True) - download_and_extract_zip(pusht_url, raw_dir) - # file is created inside a useful "pusht" directory, so we move it out and delete the dir - zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr" - shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path) - shutil.rmtree(raw_dir / "pusht") - - -def download_xarm(raw_dir: Path): - """Download all xarm datasets at once""" - import zipfile - - import gdown - - raw_dir = Path(raw_dir) - raw_dir.mkdir(parents=True, exist_ok=True) - # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py - url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - zip_path = raw_dir / "data.zip" - gdown.download(url, str(zip_path), quiet=False) - print("Extracting...") - with zipfile.ZipFile(str(zip_path), "r") as zip_f: - for pkl_path in zip_f.namelist(): - if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"): - zip_f.extract(member=pkl_path) - # move to corresponding raw directory - extract_dir = pkl_path.replace("/buffer.pkl", "") - raw_pkl_path = raw_dir / "buffer.pkl" - shutil.move(pkl_path, raw_pkl_path) - shutil.rmtree(extract_dir) - zip_path.unlink() - - -def download_hub(raw_dir: Path, dataset_id: str): - raw_dir = Path(raw_dir) + # Send warning if raw_dir isn't well formated + if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id: + warnings.warn( + f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.", + stacklevel=1, + ) raw_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}") - snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir) - logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}") + logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}") + snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir) + logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}") -def download_umi(raw_dir: Path): - url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip" - zarr_path = raw_dir / "cup_in_the_wild.zarr" +def download_all_raw_datasets(): + data_dir = Path("data") + repo_ids = [ + "cadene/pusht_image_raw", + "cadene/xarm_lift_medium_image_raw", + "cadene/xarm_lift_medium_replay_image_raw", + "cadene/xarm_push_medium_image_raw", + "cadene/xarm_push_medium_replay_image_raw", + "cadene/aloha_sim_insertion_human_image_raw", + "cadene/aloha_sim_insertion_scripted_image_raw", + "cadene/aloha_sim_transfer_cube_human_image_raw", + "cadene/aloha_sim_transfer_cube_scripted_image_raw", + "cadene/pusht_raw", + "cadene/xarm_lift_medium_raw", + "cadene/xarm_lift_medium_replay_raw", + "cadene/xarm_push_medium_raw", + "cadene/xarm_push_medium_replay_raw", + "cadene/aloha_sim_insertion_human_raw", + "cadene/aloha_sim_insertion_scripted_raw", + "cadene/aloha_sim_transfer_cube_human_raw", + "cadene/aloha_sim_transfer_cube_scripted_raw", + "cadene/aloha_mobile_cabinet_raw", + "cadene/aloha_mobile_chair_raw", + "cadene/aloha_mobile_elevator_raw", + "cadene/aloha_mobile_shrimp_raw", + "cadene/aloha_mobile_wash_pan_raw", + "cadene/aloha_mobile_wipe_wine_raw", + "cadene/aloha_static_battery_raw", + "cadene/aloha_static_candy_raw", + "cadene/aloha_static_coffee_raw", + "cadene/aloha_static_coffee_new_raw", + "cadene/aloha_static_cups_open_raw", + "cadene/aloha_static_fork_pick_up_raw", + "cadene/aloha_static_pingpong_test_raw", + "cadene/aloha_static_pro_pencil_raw", + "cadene/aloha_static_screw_driver_raw", + "cadene/aloha_static_tape_raw", + "cadene/aloha_static_thread_velcro_raw", + "cadene/aloha_static_towel_raw", + "cadene/aloha_static_vinh_cup_raw", + "cadene/aloha_static_vinh_cup_left_raw", + "cadene/aloha_static_ziploc_slide_raw", + "cadene/umi_cup_in_the_wild_raw", + ] + for repo_id in repo_ids: + raw_dir = data_dir / repo_id + download_raw(raw_dir, repo_id) - raw_dir = Path(raw_dir) - raw_dir.mkdir(parents=True, exist_ok=True) - download_and_extract_zip(url_cup_in_the_wild, zarr_path) + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--raw-dir", + type=Path, + required=True, + help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).", + ) + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).", + ) + args = parser.parse_args() + download_raw(**vars(args)) if __name__ == "__main__": - data_dir = Path("data") - dataset_ids = [ - "pusht_image", - "xarm_lift_medium_image", - "xarm_lift_medium_replay_image", - "xarm_push_medium_image", - "xarm_push_medium_replay_image", - "aloha_sim_insertion_human_image", - "aloha_sim_insertion_scripted_image", - "aloha_sim_transfer_cube_human_image", - "aloha_sim_transfer_cube_scripted_image", - "pusht", - "xarm_lift_medium", - "xarm_lift_medium_replay", - "xarm_push_medium", - "xarm_push_medium_replay", - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", - "aloha_mobile_cabinet", - "aloha_mobile_chair", - "aloha_mobile_elevator", - "aloha_mobile_shrimp", - "aloha_mobile_wash_pan", - "aloha_mobile_wipe_wine", - "aloha_static_battery", - "aloha_static_candy", - "aloha_static_coffee", - "aloha_static_coffee_new", - "aloha_static_cups_open", - "aloha_static_fork_pick_up", - "aloha_static_pingpong_test", - "aloha_static_pro_pencil", - "aloha_static_screw_driver", - "aloha_static_tape", - "aloha_static_thread_velcro", - "aloha_static_towel", - "aloha_static_vinh_cup", - "aloha_static_vinh_cup_left", - "aloha_static_ziploc_slide", - "umi_cup_in_the_wild", - ] - for dataset_id in dataset_ids: - raw_dir = data_dir / f"{dataset_id}_raw" - download_raw(raw_dir, dataset_id) + main() diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index 1c2f066e..024045a0 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -30,6 +30,7 @@ from PIL import Image as PILImage from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool: assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided." -def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): # only frames from simulation are uncompressed compressed_images = "sim" not in raw_dir.name - hdf5_files = list(raw_dir.glob("*.hdf5")) - ep_dicts = [] - episode_data_index = {"from": [], "to": []} + hdf5_files = sorted(raw_dir.glob("episode_*.hdf5")) + num_episodes = len(hdf5_files) - id_from = 0 - for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)): + ep_dicts = [] + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx in tqdm.tqdm(ep_ids): + ep_path = hdf5_files[ep_idx] with h5py.File(ep_path, "r") as ep: num_frames = ep["/action"].shape[0] @@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): if video: # save png images in temporary directory - tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): assert isinstance(ep_idx, int) ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from += num_frames - gc.collect() - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - return data_dict, episode_data_index + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict def to_hf_dataset(data_dict, video) -> Dataset: @@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset: return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) if fps is None: fps = 50 - data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug) - hf_dataset = to_hf_dataset(data_dir, video) - + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) + hf_dataset = to_hf_dataset(data_dict, video) + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py similarity index 90% rename from lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py rename to lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py index 4a21bc2d..1dc2e67e 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py @@ -17,7 +17,6 @@ Contains utilities to process raw data format from dora-record """ -import logging import re from pathlib import Path @@ -26,10 +25,10 @@ import torch from datasets import Dataset, Features, Image, Sequence, Value from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame -from lerobot.common.utils.utils import init_logging def check_format(raw_dir) -> bool: @@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool: return True -def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): # Load data stream that will be used as reference for the timestamps synchronization reference_files = list(raw_dir.glob("observation.images.cam_*.parquet")) if len(reference_files) == 0: @@ -122,8 +121,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}") # Create symlink to raw videos directory (that needs to be absolute not relative) - out_dir.mkdir(parents=True, exist_ok=True) - videos_dir = out_dir / "videos" + videos_dir.parent.mkdir(parents=True, exist_ok=True) videos_dir.symlink_to((raw_dir / "videos").absolute()) # sanity check the video paths are well formated @@ -156,16 +154,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): else: raise ValueError(key) - # Get the episode index containing for each unique episode index - first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index() - from_ = first_ep_index_df["start_index"].tolist() - to_ = from_[1:] + [len(df)] - episode_data_index = { - "from": from_, - "to": to_, - } - - return data_dict, episode_data_index + return data_dict def to_hf_dataset(data_dict, video) -> Dataset: @@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset: return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): - init_logging() - - if debug: - logging.warning("debug=True not implemented. Falling back to debug=False.") - +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) @@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru if not video: raise NotImplementedError() - data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps) + data_df = load_from_raw(raw_dir, videos_dir, fps, episodes) hf_dataset = to_hf_dataset(data_df, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py index 8133a36a..d9c7eb65 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py @@ -27,6 +27,7 @@ from PIL import Image as PILImage from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -53,7 +54,7 @@ def check_format(raw_dir): assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) -def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): try: import pymunk from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely @@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) episode_ids = torch.from_numpy(zarr_data.get_episode_idxs()) - num_episodes = zarr_data.meta["episode_ends"].shape[0] assert len( {zarr_data[key].shape[0] for key in zarr_data.keys()} # noqa: SIM118 ), "Some data type dont have the same number of total frames." @@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): states = torch.from_numpy(zarr_data["state"]) actions = torch.from_numpy(zarr_data["action"]) - ep_dicts = [] - episode_data_index = {"from": [], "to": []} + # load data indices from which each episode starts and ends + from_ids, to_ids = [], [] + from_idx = 0 + for to_idx in zarr_data.meta["episode_ends"]: + from_ids.append(from_idx) + to_ids.append(to_idx) + from_idx = to_idx - id_from = 0 - for ep_idx in tqdm.tqdm(range(num_episodes)): - id_to = zarr_data.meta["episode_ends"][ep_idx] - num_frames = id_to - id_from + num_episodes = len(from_ids) + + ep_dicts = [] + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)): + from_idx = from_ids[selected_ep_idx] + to_idx = to_ids[selected_ep_idx] + num_frames = to_idx - from_idx # sanity check - assert (episode_ids[id_from:id_to] == ep_idx).all() + assert (episode_ids[from_idx:to_idx] == ep_idx).all() # get image - image = imgs[id_from:id_to] + image = imgs[from_idx:to_idx] assert image.min() >= 0.0 assert image.max() <= 255.0 image = image.type(torch.uint8) # get state - state = states[id_from:id_to] + state = states[from_idx:to_idx] agent_pos = state[:, :2] block_pos = state[:, 2:4] block_angle = state[:, 4] @@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): img_key = "observation.image" if video: # save png images in temporary directory - tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array] ep_dict["observation.state"] = agent_pos - ep_dict["action"] = actions[id_from:id_to] + ep_dict["action"] = actions[from_idx:to_idx] ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps @@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]]) ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from += num_frames - - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - return data_dict, episode_data_index + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict def to_hf_dataset(data_dict, video): @@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video): return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) if fps is None: fps = 10 - data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug) + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) hf_dataset = to_hf_dataset(data_dict, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py index cab2bdc5..6cd80c61 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py @@ -19,7 +19,6 @@ import logging import shutil from pathlib import Path -import numpy as np import torch import tqdm import zarr @@ -29,6 +28,7 @@ from PIL import Image as PILImage from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool: assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) -def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray: - # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374 - from numba import jit - - @jit(nopython=True) - def _get_episode_idxs(episode_ends): - result = np.zeros((episode_ends[-1],), dtype=np.int64) - start_idx = 0 - for episode_number, end_idx in enumerate(episode_ends): - result[start_idx:end_idx] = episode_number - start_idx = end_idx - return result - - return _get_episode_idxs(episode_ends) - - -def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): zarr_path = raw_dir / "cup_in_the_wild.zarr" zarr_data = zarr.open(zarr_path, mode="r") @@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): episode_ends = zarr_data["meta/episode_ends"][:] num_episodes = episode_ends.shape[0] - episode_ids = torch.from_numpy(get_episode_idxs(episode_ends)) - # We convert it in torch tensor later because the jit function does not support torch tensors episode_ends = torch.from_numpy(episode_ends) + # load data indices from which each episode starts and ends + from_ids, to_ids = [], [] + from_idx = 0 + for to_idx in episode_ends: + from_ids.append(from_idx) + to_ids.append(to_idx) + from_idx = to_idx + ep_dicts = [] - episode_data_index = {"from": [], "to": []} - - id_from = 0 - for ep_idx in tqdm.tqdm(range(num_episodes)): - id_to = episode_ends[ep_idx] - num_frames = id_to - id_from - - # sanity heck - assert (episode_ids[id_from:id_to] == ep_idx).all() + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)): + from_idx = from_ids[selected_ep_idx] + to_idx = to_ids[selected_ep_idx] + num_frames = to_idx - from_idx # TODO(rcadene): save temporary images of the episode? - state = states[id_from:id_to] + state = states[from_idx:to_idx] ep_dict = {} # load 57MB of images in RAM (400x224x224x3 uint8) - imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to] + imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx] img_key = "observation.image" if video: # save png images in temporary directory - tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps - ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames) - ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames) - ep_dict["end_pose"] = end_pose[id_from:id_to] - ep_dict["start_pos"] = start_pos[id_from:id_to] - ep_dict["gripper_width"] = gripper_width[id_from:id_to] + ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames) + ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames) + ep_dict["end_pose"] = end_pose[from_idx:to_idx] + ep_dict["start_pos"] = start_pos[from_idx:to_idx] + ep_dict["gripper_width"] = gripper_width[from_idx:to_idx] ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - id_from += num_frames - - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - total_frames = id_from + total_frames = data_dict["frame_index"].shape[0] data_dict["index"] = torch.arange(0, total_frames, 1) - - return data_dict, episode_data_index + return data_dict def to_hf_dataset(data_dict, video): @@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video): return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) @@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM." ) - data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug) + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) hf_dataset = to_hf_dataset(data_dict, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py index 899ebdde..57a36dba 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py @@ -27,6 +27,7 @@ from PIL import Image as PILImage from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -54,37 +55,42 @@ def check_format(raw_dir): assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict) -def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): pkl_path = raw_dir / "buffer.pkl" with open(pkl_path, "rb") as f: pkl_data = pickle.load(f) - ep_dicts = [] - episode_data_index = {"from": [], "to": []} - - id_from = 0 - id_to = 0 - ep_idx = 0 - total_frames = pkl_data["actions"].shape[0] - for i in tqdm.tqdm(range(total_frames)): - id_to += 1 - - if not pkl_data["dones"][i]: + # load data indices from which each episode starts and ends + from_ids, to_ids = [], [] + from_idx, to_idx = 0, 0 + for done in pkl_data["dones"]: + to_idx += 1 + if not done: continue + from_ids.append(from_idx) + to_ids.append(to_idx) + from_idx = to_idx - num_frames = id_to - id_from + num_episodes = len(from_ids) - image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to]) + ep_dicts = [] + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)): + from_idx = from_ids[selected_ep_idx] + to_idx = to_ids[selected_ep_idx] + num_frames = to_idx - from_idx + + image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx]) image = einops.rearrange(image, "b c h w -> b h w c") - state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to]) - action = torch.tensor(pkl_data["actions"][id_from:id_to]) + state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx]) + action = torch.tensor(pkl_data["actions"][from_idx:to_idx]) # TODO(rcadene): we have a missing last frame which is the observation when the env is done # it is critical to have this frame for tdmpc to predict a "done observation/state" - # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to]) - # next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to]) - next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to]) - next_done = torch.tensor(pkl_data["dones"][id_from:id_to]) + # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx]) + # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx]) + next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx]) + next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx]) ep_dict = {} @@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): img_key = "observation.image" if video: # save png images in temporary directory - tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict["next.done"] = next_done ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from = id_to - ep_idx += 1 - - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - return data_dict, episode_data_index + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict def to_hf_dataset(data_dict, video): @@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video): return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) if fps is None: fps = 15 - data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug) + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) hf_dataset = to_hf_dataset(data_dict, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py new file mode 100644 index 00000000..899f0d66 --- /dev/null +++ b/lerobot/common/datasets/transforms.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +from typing import Any, Callable, Dict, Sequence + +import torch +from torchvision.transforms import v2 +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2 import functional as F # noqa: N812 + + +class RandomSubsetApply(Transform): + """Apply a random subset of N transformations from a list of transformations. + + Args: + transforms: list of transformations. + p: represents the multinomial probabilities (with no replacement) used for sampling the transform. + If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms + have the same probability. + n_subset: number of transformations to apply. If ``None``, all transforms are applied. + Must be in [1, len(transforms)]. + random_order: apply transformations in a random order. + """ + + def __init__( + self, + transforms: Sequence[Callable], + p: list[float] | None = None, + n_subset: int | None = None, + random_order: bool = False, + ) -> None: + super().__init__() + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + if p is None: + p = [1] * len(transforms) + elif len(p) != len(transforms): + raise ValueError( + f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}" + ) + + if n_subset is None: + n_subset = len(transforms) + elif not isinstance(n_subset, int): + raise TypeError("n_subset should be an int or None") + elif not (1 <= n_subset <= len(transforms)): + raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") + + self.transforms = transforms + total = sum(p) + self.p = [prob / total for prob in p] + self.n_subset = n_subset + self.random_order = random_order + + def forward(self, *inputs: Any) -> Any: + needs_unpacking = len(inputs) > 1 + + selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset) + if not self.random_order: + selected_indices = selected_indices.sort().values + + selected_transforms = [self.transforms[i] for i in selected_indices] + + for transform in selected_transforms: + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + + return outputs + + def extra_repr(self) -> str: + return ( + f"transforms={self.transforms}, " + f"p={self.p}, " + f"n_subset={self.n_subset}, " + f"random_order={self.random_order}" + ) + + +class SharpnessJitter(Transform): + """Randomly change the sharpness of an image or video. + + Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly. + While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image, + SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of + augmentations as a result. + + A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness + by a factor of 2. + + If the input is a :class:`torch.Tensor`, + it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from + [max(0, 1 - sharpness), 1 + sharpness] or the given + [min, max]. Should be non negative numbers. + """ + + def __init__(self, sharpness: float | Sequence[float]) -> None: + super().__init__() + self.sharpness = self._check_input(sharpness) + + def _check_input(self, sharpness): + if isinstance(sharpness, (int, float)): + if sharpness < 0: + raise ValueError("If sharpness is a single number, it must be non negative.") + sharpness = [1.0 - sharpness, 1.0 + sharpness] + sharpness[0] = max(sharpness[0], 0.0) + elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2: + sharpness = [float(v) for v in sharpness] + else: + raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") + + if not 0.0 <= sharpness[0] <= sharpness[1]: + raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") + + return float(sharpness[0]), float(sharpness[1]) + + def _generate_value(self, left: float, right: float) -> float: + return torch.empty(1).uniform_(left, right).item() + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1]) + return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) + + +def get_image_transforms( + brightness_weight: float = 1.0, + brightness_min_max: tuple[float, float] | None = None, + contrast_weight: float = 1.0, + contrast_min_max: tuple[float, float] | None = None, + saturation_weight: float = 1.0, + saturation_min_max: tuple[float, float] | None = None, + hue_weight: float = 1.0, + hue_min_max: tuple[float, float] | None = None, + sharpness_weight: float = 1.0, + sharpness_min_max: tuple[float, float] | None = None, + max_num_transforms: int | None = None, + random_order: bool = False, +): + def check_value(name, weight, min_max): + if min_max is not None: + if len(min_max) != 2: + raise ValueError( + f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided." + ) + if weight < 0.0: + raise ValueError( + f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})." + ) + + check_value("brightness", brightness_weight, brightness_min_max) + check_value("contrast", contrast_weight, contrast_min_max) + check_value("saturation", saturation_weight, saturation_min_max) + check_value("hue", hue_weight, hue_min_max) + check_value("sharpness", sharpness_weight, sharpness_min_max) + + weights = [] + transforms = [] + if brightness_min_max is not None and brightness_weight > 0.0: + weights.append(brightness_weight) + transforms.append(v2.ColorJitter(brightness=brightness_min_max)) + if contrast_min_max is not None and contrast_weight > 0.0: + weights.append(contrast_weight) + transforms.append(v2.ColorJitter(contrast=contrast_min_max)) + if saturation_min_max is not None and saturation_weight > 0.0: + weights.append(saturation_weight) + transforms.append(v2.ColorJitter(saturation=saturation_min_max)) + if hue_min_max is not None and hue_weight > 0.0: + weights.append(hue_weight) + transforms.append(v2.ColorJitter(hue=hue_min_max)) + if sharpness_min_max is not None and sharpness_weight > 0.0: + weights.append(sharpness_weight) + transforms.append(SharpnessJitter(sharpness=sharpness_min_max)) + + n_subset = len(transforms) + if max_num_transforms is not None: + n_subset = min(n_subset, max_num_transforms) + + if n_subset == 0: + return v2.Identity() + else: + # TODO(rcadene, aliberts): add v2.ToDtype float16? + return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 71d961a1..b76d9b67 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -238,5 +238,6 @@ class Logger: def log_video(self, video_path: str, step: int, mode: str = "train"): assert mode in {"train", "eval"} + assert self._wandb is not None wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4") self._wandb.log({f"{mode}/video": wandb_video}, step=step) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 273f4f75..e0482143 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -239,10 +239,8 @@ class DiffusionModel(nn.Module): global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) # run sampling - sample = self.conditional_sample(batch_size, global_cond=global_cond) + actions = self.conditional_sample(batch_size, global_cond=global_cond) - # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.config.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 end = start + self.config.n_action_steps diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index d638c541..9b055f7e 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -147,7 +147,7 @@ class Normalize(nn.Module): assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(max).any(), _no_stats_error_str("max") # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min) + batch[key] = (batch[key] - min) / (max - min + 1e-8) # normalize to [-1, 1] batch[key] = batch[key] * 2 - 1 else: diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index 38738a90..4e9e87af 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -57,7 +57,7 @@ class Policy(Protocol): other items should be logging-friendly, native Python types. """ - def select_action(self, batch: dict[str, Tensor]): + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Return one action to run in the environment (potentially in batch mode). When the model uses a history of observations, or outputs a sequence of actions, this method deals diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 7c873bf2..de9658e9 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin): self._prev_mean: torch.Tensor | None = None @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]): + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) batch["observation.image"] = batch[self.input_image_key] diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 85b9ceea..6101df89 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -43,6 +43,40 @@ training: save_checkpoint: true num_workers: 4 batch_size: ??? + image_transforms: + # These transforms are all using standard torchvision.transforms.v2 + # You can find out how these transformations affect images here: + # https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html + # We use a custom RandomSubsetApply container to sample them. + # For each transform, the following parameters are available: + # weight: This represents the multinomial probability (with no replacement) + # used for sampling the transform. If the sum of the weights is not 1, + # they will be normalized. + # min_max: Lower & upper bound respectively used for sampling the transform's parameter + # (following uniform distribution) when it's applied. + # Set this flag to `true` to enable transforms during training + enable: false + # This is the maximum number of transforms (sampled from these below) that will be applied to each frame. + # It's an integer in the interval [1, number of available transforms]. + max_num_transforms: 3 + # By default, transforms are applied in Torchvision's suggested order (shown below). + # Set this to True to apply them in a random order. + random_order: false + brightness: + weight: 1 + min_max: [0.8, 1.2] + contrast: + weight: 1 + min_max: [0.8, 1.2] + saturation: + weight: 1 + min_max: [0.5, 1.5] + hue: + weight: 1 + min_max: [-0.05, 0.05] + sharpness: + weight: 1 + min_max: [0.8, 1.2] eval: n_episodes: 1 diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py index 4d8b4850..4d3cc291 100644 --- a/lerobot/scripts/display_sys_info.py +++ b/lerobot/scripts/display_sys_info.py @@ -13,39 +13,71 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Use this script to get a quick summary of your system config. +It should be able to run without any of LeRobot's dependencies or LeRobot itself installed. +""" + import platform -import huggingface_hub +HAS_HF_HUB = True +HAS_HF_DATASETS = True +HAS_NP = True +HAS_TORCH = True +HAS_LEROBOT = True -# import dataset -import numpy as np -import torch +try: + import huggingface_hub +except ImportError: + HAS_HF_HUB = False -from lerobot import __version__ as version +try: + import datasets +except ImportError: + HAS_HF_DATASETS = False -pt_version = torch.__version__ -pt_cuda_available = torch.cuda.is_available() -pt_cuda_available = torch.cuda.is_available() -cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A" +try: + import numpy as np +except ImportError: + HAS_NP = False + +try: + import torch +except ImportError: + HAS_TORCH = False + +try: + import lerobot +except ImportError: + HAS_LEROBOT = False + + +lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A" +hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A" +hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A" +np_version = np.__version__ if HAS_NP else "N/A" + +torch_version = torch.__version__ if HAS_TORCH else "N/A" +torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" +cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" # TODO(aliberts): refactor into an actual command `lerobot env` def display_sys_info() -> dict: """Run this to get basic system info to help for tracking issues & bugs.""" info = { - "`lerobot` version": version, + "`lerobot` version": lerobot_version, "Platform": platform.platform(), "Python version": platform.python_version(), - "Huggingface_hub version": huggingface_hub.__version__, - # TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged - # "Dataset version": dataset.__version__, - "Numpy version": np.__version__, - "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Huggingface_hub version": hf_hub_version, + "Dataset version": hf_datasets_version, + "Numpy version": np_version, + "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})", "Cuda version": cuda_version, "Using GPU in script?": "", - "Using distributed or parallel set-up in script?": "", + # "Using distributed or parallel set-up in script?": "", } - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") print(format_dict(info)) return info diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 784e9fc6..7bf8bde5 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -61,7 +61,7 @@ from huggingface_hub import snapshot_download from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError from PIL import Image as PILImage -from torch import Tensor +from torch import Tensor, nn from tqdm import trange from lerobot.common.datasets.factory import make_dataset @@ -99,13 +99,13 @@ def rollout( "reward": A (batch, sequence) tensor of rewards received for applying the actions. "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon environment termination/truncation). - "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, + "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, the first True is followed by True's all the way till the end. This can be used for masking extraneous elements from the sequences above. Args: env: The batch of environments. - policy: The policy. + policy: The policy. Must be a PyTorch nn module. seeds: The environments are seeded once at the start of the rollout. If provided, this argument specifies the seeds for each of the environments. return_observations: Whether to include all observations in the returned rollout data. Observations @@ -116,6 +116,7 @@ def rollout( Returns: The dictionary described above. """ + assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." device = get_device_from_parameters(policy) # Reset the policy and environments. @@ -209,7 +210,7 @@ def eval_policy( policy: torch.nn.Module, n_episodes: int, max_episodes_rendered: int = 0, - video_dir: Path | None = None, + videos_dir: Path | None = None, return_episode_data: bool = False, start_seed: int | None = None, enable_progbar: bool = False, @@ -221,7 +222,7 @@ def eval_policy( policy: The policy. n_episodes: The number of episodes to evaluate. max_episodes_rendered: Maximum number of episodes to render into videos. - video_dir: Where to save rendered videos. + videos_dir: Where to save rendered videos. return_episode_data: Whether to return episode data for online training. Incorporates the data into the "episodes" key of the returned dictionary. start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the @@ -231,6 +232,10 @@ def eval_policy( Returns: Dictionary with metrics and data regarding the rollouts. """ + if max_episodes_rendered > 0 and not videos_dir: + raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") + + assert isinstance(policy, Policy) start = time.time() policy.eval() @@ -271,11 +276,16 @@ def eval_policy( if max_episodes_rendered > 0: ep_frames: list[np.ndarray] = [] - seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)) + if start_seed is None: + seeds = None + else: + seeds = range( + start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) + ) rollout_data = rollout( env, policy, - seeds=seeds, + seeds=list(seeds) if seeds else None, return_observations=return_episode_data, render_callback=render_frame if max_episodes_rendered > 0 else None, enable_progbar=enable_inner_progbar, @@ -285,7 +295,8 @@ def eval_policy( # this won't be included). n_steps = rollout_data["done"].shape[1] # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. - done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps) + done_indices = torch.argmax(rollout_data["done"].to(int), dim=1) + # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() @@ -296,8 +307,12 @@ def eval_policy( max_rewards.extend(batch_max_rewards.tolist()) batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") all_successes.extend(batch_successes.tolist()) - all_seeds.extend(seeds) + if seeds: + all_seeds.extend(seeds) + else: + all_seeds.append(None) + # FIXME: episode_data is either None or it doesn't exist if return_episode_data: this_episode_data = _compile_episode_data( rollout_data, @@ -347,8 +362,9 @@ def eval_policy( ): if n_episodes_rendered >= max_episodes_rendered: break - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4" + + videos_dir.mkdir(parents=True, exist_ok=True) + video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4" video_paths.append(str(video_path)) thread = threading.Thread( target=write_video, @@ -503,22 +519,20 @@ def _compile_episode_data( } -def eval( - pretrained_policy_path: str | None = None, +def main( + pretrained_policy_path: Path | None = None, hydra_cfg_path: str | None = None, + out_dir: str | None = None, config_overrides: list[str] | None = None, ): assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None) - if hydra_cfg_path is None: - hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides) + if pretrained_policy_path is not None: + hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides) else: hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides) - out_dir = ( - f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" - ) if out_dir is None: - raise NotImplementedError() + out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" # Check device is available device = get_safe_torch_device(hydra_cfg.device, log=True) @@ -534,10 +548,12 @@ def eval( logging.info("Making policy.") if hydra_cfg_path is None: - policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) + policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path)) else: # Note: We need the dataset stats to pass to the policy's normalization modules. policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) + + assert isinstance(policy, nn.Module) policy.eval() with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(): @@ -546,7 +562,7 @@ def eval( policy, hydra_cfg.eval.n_episodes, max_episodes_rendered=10, - video_dir=Path(out_dir) / "eval", + videos_dir=Path(out_dir) / "videos", start_seed=hydra_cfg.seed, enable_progbar=True, enable_inner_progbar=True, @@ -586,6 +602,13 @@ if __name__ == "__main__": ), ) parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") + parser.add_argument( + "--out-dir", + help=( + "Where to save the evaluation outputs. If not provided, outputs are saved in " + "outputs/eval/{timestamp}_{env_name}_{policy_name}" + ), + ) parser.add_argument( "overrides", nargs="*", @@ -594,7 +617,7 @@ if __name__ == "__main__": args = parser.parse_args() if args.pretrained_policy_name_or_path is None: - eval(hydra_cfg_path=args.config, config_overrides=args.overrides) + main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides) else: try: pretrained_policy_path = Path( @@ -618,4 +641,8 @@ if __name__ == "__main__": "repo ID, nor is it an existing local directory." ) - eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides) + main( + pretrained_policy_path=pretrained_policy_path, + out_dir=args.out_dir, + config_overrides=args.overrides, + ) diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 52252b57..18714a40 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -18,57 +18,39 @@ Use this script to convert your dataset into LeRobot dataset format and upload i or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any installation of neural net specific packages like pytorch, tensorflow, jax. -Example: +Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub: ``` python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id pusht \ +--raw-dir data/pusht_raw \ --raw-format pusht_zarr \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/pusht python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id xarm_lift_medium \ +--raw-dir data/xarm_lift_medium_raw \ --raw-format xarm_pkl \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/xarm_lift_medium python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id aloha_sim_insertion_scripted \ +--raw-dir data/aloha_sim_insertion_scripted_raw \ --raw-format aloha_hdf5 \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/aloha_sim_insertion_scripted python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id umi_cup_in_the_wild \ +--raw-dir data/umi_cup_in_the_wild_raw \ --raw-format umi_zarr \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/umi_cup_in_the_wild ``` """ import argparse import json import shutil +import warnings from pathlib import Path +from typing import Any import torch -from huggingface_hub import HfApi +from huggingface_hub import HfApi, create_branch from safetensors.torch import save_file from lerobot.common.datasets.compute_stats import compute_stats @@ -77,15 +59,15 @@ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_r from lerobot.common.datasets.utils import flatten_dict -def get_from_raw_to_lerobot_format_fn(raw_format): +def get_from_raw_to_lerobot_format_fn(raw_format: str): if raw_format == "pusht_zarr": from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format elif raw_format == "umi_zarr": from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format elif raw_format == "aloha_hdf5": from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format - elif raw_format == "aloha_dora": - from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format + elif raw_format == "dora_parquet": + from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format elif raw_format == "xarm_pkl": from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format else: @@ -96,7 +78,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format): return from_raw_to_lerobot_format -def save_meta_data(info, stats, episode_data_index, meta_data_dir): +def save_meta_data( + info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path +): meta_data_dir.mkdir(parents=True, exist_ok=True) # save info @@ -114,7 +98,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir): save_file(episode_data_index, ep_data_idx_path) -def push_meta_data_to_hub(repo_id, meta_data_dir, revision): +def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None): """Expect all meta data files to be all stored in a single "meta_data" directory. On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root. """ @@ -128,7 +112,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision): ) -def push_videos_to_hub(repo_id, videos_dir, revision): +def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None): """Expect mp4 files to be all stored in a single "videos" directory. On the hugging face repositery, they will be uploaded in a "videos" directory at the root. """ @@ -144,39 +128,61 @@ def push_videos_to_hub(repo_id, videos_dir, revision): def push_dataset_to_hub( - data_dir: Path, - dataset_id: str, - raw_format: str | None, - community_id: str, - revision: str, - dry_run: bool, - save_to_disk: bool, - tests_data_dir: Path, - save_tests_to_disk: bool, - fps: int | None, - video: bool, - batch_size: int, - num_workers: int, - debug: bool, + raw_dir: Path, + raw_format: str, + repo_id: str, + push_to_hub: bool = True, + local_dir: Path | None = None, + fps: int | None = None, + video: bool = True, + batch_size: int = 32, + num_workers: int = 8, + episodes: list[int] | None = None, + force_override: bool = False, + cache_dir: Path = Path("/tmp"), + tests_data_dir: Path | None = None, ): - repo_id = f"{community_id}/{dataset_id}" + # Check repo_id is well formated + if len(repo_id.split("/")) != 2: + raise ValueError( + f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'." + ) + user_id, dataset_id = repo_id.split("/") - raw_dir = data_dir / f"{dataset_id}_raw" + # Robustify when `raw_dir` is str instead of Path + raw_dir = Path(raw_dir) + if not raw_dir.exists(): + raise NotADirectoryError( + f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:" + f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw" + ) - out_dir = data_dir / repo_id - meta_data_dir = out_dir / "meta_data" - videos_dir = out_dir / "videos" + if local_dir: + # Robustify when `local_dir` is str instead of Path + local_dir = Path(local_dir) - tests_out_dir = tests_data_dir / repo_id - tests_meta_data_dir = tests_out_dir / "meta_data" - tests_videos_dir = tests_out_dir / "videos" + # Send warning if local_dir isn't well formated + if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id: + warnings.warn( + f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.", + stacklevel=1, + ) - if out_dir.exists(): - shutil.rmtree(out_dir) + # Check we don't override an existing `local_dir` by mistake + if local_dir.exists(): + if force_override: + shutil.rmtree(local_dir) + else: + raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.") - if tests_out_dir.exists() and save_tests_to_disk: - shutil.rmtree(tests_out_dir) + meta_data_dir = local_dir / "meta_data" + videos_dir = local_dir / "videos" + else: + # Temporary directory used to store images, videos, meta_data + meta_data_dir = Path(cache_dir) / "meta_data" + videos_dir = Path(cache_dir) / "videos" + # Download the raw dataset if available if not raw_dir.exists(): download_raw(raw_dir, dataset_id) @@ -185,14 +191,14 @@ def push_dataset_to_hub( raise NotImplementedError() # raw_format = auto_find_raw_format(raw_dir) - from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) - # convert dataset from original raw format to LeRobot format - hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug) + from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) + hf_dataset, episode_data_index, info = from_raw_to_lerobot_format( + raw_dir, videos_dir, fps, video, episodes + ) lerobot_dataset = LeRobotDataset.from_preloaded( repo_id=repo_id, - version=revision, hf_dataset=hf_dataset, episode_data_index=episode_data_index, info=info, @@ -200,102 +206,80 @@ def push_dataset_to_hub( ) stats = compute_stats(lerobot_dataset, batch_size, num_workers) - if save_to_disk: + if local_dir: hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved - hf_dataset.save_to_disk(str(out_dir / "train")) + hf_dataset.save_to_disk(str(local_dir / "train")) - if not dry_run or save_to_disk: + if push_to_hub or local_dir: # mandatory for upload save_meta_data(info, stats, episode_data_index, meta_data_dir) - if not dry_run: - hf_dataset.push_to_hub(repo_id, token=True, revision="main") - hf_dataset.push_to_hub(repo_id, token=True, revision=revision) - + if push_to_hub: + hf_dataset.push_to_hub(repo_id, revision="main") push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") - push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision) - if video: push_videos_to_hub(repo_id, videos_dir, revision="main") - push_videos_to_hub(repo_id, videos_dir, revision=revision) + create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) - if save_tests_to_disk: + if tests_data_dir: # get the first episode num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] test_hf_dataset = hf_dataset.select(range(num_items_first_ep)) test_hf_dataset = test_hf_dataset.with_format(None) - test_hf_dataset.save_to_disk(str(tests_out_dir / "train")) + test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train")) - save_meta_data(info, stats, episode_data_index, tests_meta_data_dir) + tests_meta_data = tests_data_dir / repo_id / "meta_data" + save_meta_data(info, stats, episode_data_index, tests_meta_data) # copy videos of first episode to tests directory episode_index = 0 + tests_videos_dir = tests_data_dir / repo_id / "videos" tests_videos_dir.mkdir(parents=True, exist_ok=True) for key in lerobot_dataset.video_frame_keys: fname = f"{key}_episode_{episode_index:06d}.mp4" shutil.copy(videos_dir / fname, tests_videos_dir / fname) - if not save_to_disk and out_dir.exists(): - # remove possible temporary files remaining in the output directory - shutil.rmtree(out_dir) + if local_dir is None: + # clear cache + shutil.rmtree(meta_data_dir) + shutil.rmtree(videos_dir) + + return lerobot_dataset def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--data-dir", + "--raw-dir", type=Path, required=True, - help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).", - ) - parser.add_argument( - "--dataset-id", - type=str, - required=True, - help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).", + help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).", ) + # TODO(rcadene): add automatic detection of the format parser.add_argument( "--raw-format", type=str, - help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.", + required=True, + help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).", ) parser.add_argument( - "--community-id", + "--repo-id", type=str, - default="lerobot", - help="Community or user ID under which the dataset will be hosted on the Hub.", + required=True, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", ) parser.add_argument( - "--revision", - type=str, - default=CODEBASE_VERSION, - help="Codebase version used to generate the dataset.", - ) - parser.add_argument( - "--dry-run", - type=int, - default=0, - help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.", - ) - parser.add_argument( - "--save-to-disk", - type=int, - default=1, - help="Save the dataset in the directory specified by `--data-dir`.", - ) - parser.add_argument( - "--tests-data-dir", + "--local-dir", type=Path, - default="tests/data", - help="Directory containing tests artifacts datasets.", + help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).", ) parser.add_argument( - "--save-tests-to-disk", + "--push-to-hub", type=int, default=1, - help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.", + help="Upload to hub.", ) parser.add_argument( "--fps", @@ -321,10 +305,21 @@ def main(): help="Number of processes of Dataloader for computing the dataset statistics.", ) parser.add_argument( - "--debug", + "--episodes", + type=int, + nargs="*", + help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.", + ) + parser.add_argument( + "--force-override", type=int, default=0, - help="Debug mode process the first episode only.", + help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.", + ) + parser.add_argument( + "--tests-data-dir", + type=Path, + help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).", ) args = parser.parse_args() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index e63a5633..125a1b41 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -24,6 +24,7 @@ import torch from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored +from torch import nn from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps @@ -150,6 +151,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline): grad_norm = info["grad_norm"] lr = info["lr"] update_s = info["update_s"] + dataloading_s = info["dataloading_s"] # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. @@ -170,6 +172,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline): f"lr:{lr:0.1e}", # in seconds f"updt_s:{update_s:.3f}", + f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io ] logging.info(" ".join(log_items)) @@ -290,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. + eval_env = None if cfg.training.eval_freq > 0: logging.info("make_env") eval_env = make_env(cfg) @@ -300,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dataset_stats=offline_dataset.stats if not cfg.resume else None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, ) - + assert isinstance(policy, nn.Module) # Create optimizer and scheduler # Temporary hack to move optimizer out of policy optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) @@ -325,14 +329,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Note: this helper will be used in offline and online training loops. def evaluate_and_checkpoint_if_needed(step): + _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) + step_identifier = f"{step:0{_num_digits}d}" + if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): + assert eval_env is not None eval_info = eval_policy( eval_env, policy, cfg.eval.n_episodes, - video_dir=Path(out_dir) / "eval", + videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}", max_episodes_rendered=4, start_seed=cfg.seed, ) @@ -350,9 +358,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No policy, optimizer, lr_scheduler, - identifier=str(step).zfill( - max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) - ), + identifier=step_identifier, ) logging.info("Resume training") @@ -382,7 +388,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No for _ in range(step, cfg.training.offline_steps): if step == 0: logging.info("Start offline training on a fixed dataset") + + start_time = time.perf_counter() batch = next(dl_iter) + dataloading_s = time.perf_counter() - start_time for key in batch: batch[key] = batch[key].to(device, non_blocking=True) @@ -397,6 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No use_amp=cfg.use_amp, ) + train_info["dataloading_s"] = dataloading_s + if step % cfg.training.log_freq == 0: log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True) @@ -406,7 +417,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No step += 1 - eval_env.close() + if eval_env: + eval_env.close() logging.info("End of training") diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 138084ae..f947e610 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -66,28 +66,31 @@ import gc import logging import time from pathlib import Path +from typing import Iterator +import numpy as np import rerun as rr import torch +import torch.utils.data import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset class EpisodeSampler(torch.utils.data.Sampler): - def __init__(self, dataset, episode_index): + def __init__(self, dataset: LeRobotDataset, episode_index: int): from_idx = dataset.episode_data_index["from"][episode_index].item() to_idx = dataset.episode_data_index["to"][episode_index].item() self.frame_ids = range(from_idx, to_idx) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.frame_ids) - def __len__(self): + def __len__(self) -> int: return len(self.frame_ids) -def to_hwc_uint8_numpy(chw_float32_torch): +def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.ndim == 3 c, h, w = chw_float32_torch.shape @@ -106,6 +109,7 @@ def visualize_dataset( ws_port: int = 9087, save: bool = False, output_dir: Path | None = None, + root: Path | None = None, ) -> Path | None: if save: assert ( @@ -113,7 +117,7 @@ def visualize_dataset( ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id) + dataset = LeRobotDataset(repo_id, root=root) logging.info("Loading dataloader") episode_sampler = EpisodeSampler(dataset, episode_index) @@ -224,7 +228,8 @@ def main(): help=( "Mode of viewing between 'local' or 'distant'. " "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " - "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + "'distant' creates a server on the distant machine where the data is stored. " + "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." ), ) parser.add_argument( @@ -245,8 +250,8 @@ def main(): default=0, help=( "Save a .rrd file in the directory provided by `--output-dir`. " - "It also deactivates the spawning of a viewer. ", - "Visualize the data by running `rerun path/to/file.rrd` on your local machine.", + "It also deactivates the spawning of a viewer. " + "Visualize the data by running `rerun path/to/file.rrd` on your local machine." ), ) parser.add_argument( diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py new file mode 100644 index 00000000..fa3c0ab2 --- /dev/null +++ b/lerobot/scripts/visualize_image_transforms.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Visualize effects of image transforms for a given configuration. + +This script will generate examples of transformed images as they are output by LeRobot dataset. +Additionally, each individual transform can be visualized separately as well as examples of combined transforms + + +--- Usage Examples --- + +Increase hue jitter +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.hue.min_max=[-0.25,0.25] +``` + +Increase brightness & brightness weight +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.brightness.weight=10.0 \ + training.image_transforms.brightness.min_max=[1.0,2.0] +``` + +Blur images and disable saturation & hue +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.sharpness.weight=10.0 \ + training.image_transforms.sharpness.min_max=[0.0,1.0] \ + training.image_transforms.saturation.weight=0.0 \ + training.image_transforms.hue.weight=0.0 +``` + +Use all transforms with random order +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.max_num_transforms=5 \ + training.image_transforms.random_order=true +``` + +""" + +from pathlib import Path + +import hydra +from torchvision.transforms import ToPILImage + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms + +OUTPUT_DIR = Path("outputs/image_transforms") +N_EXAMPLES = 5 +to_pil = ToPILImage() + + +def save_config_all_transforms(cfg, original_frame, output_dir): + tf = get_image_transforms( + brightness_weight=cfg.brightness.weight, + brightness_min_max=cfg.brightness.min_max, + contrast_weight=cfg.contrast.weight, + contrast_min_max=cfg.contrast.min_max, + saturation_weight=cfg.saturation.weight, + saturation_min_max=cfg.saturation.min_max, + hue_weight=cfg.hue.weight, + hue_min_max=cfg.hue.min_max, + sharpness_weight=cfg.sharpness.weight, + sharpness_min_max=cfg.sharpness.min_max, + max_num_transforms=cfg.max_num_transforms, + random_order=cfg.random_order, + ) + + output_dir_all = output_dir / "all" + output_dir_all.mkdir(parents=True, exist_ok=True) + + for i in range(1, N_EXAMPLES + 1): + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100) + + print("Combined transforms examples saved to:") + print(f" {output_dir_all}") + + +def save_config_single_transforms(cfg, original_frame, output_dir): + transforms = [ + "brightness", + "contrast", + "saturation", + "hue", + "sharpness", + ] + print("Individual transforms examples saved to:") + for transform in transforms: + kwargs = { + f"{transform}_weight": cfg[f"{transform}"].weight, + f"{transform}_min_max": cfg[f"{transform}"].min_max, + } + tf = get_image_transforms(**kwargs) + output_dir_single = output_dir / f"{transform}" + output_dir_single.mkdir(parents=True, exist_ok=True) + + for i in range(1, N_EXAMPLES + 1): + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100) + + print(f" {output_dir_single}") + + +@hydra.main(version_base="1.2", config_name="default", config_path="../configs") +def visualize_transforms(cfg): + dataset = LeRobotDataset(cfg.dataset_repo_id) + + output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1] + output_dir.mkdir(parents=True, exist_ok=True) + + # Get 1st frame from 1st camera of 1st episode + original_frame = dataset[0][dataset.camera_keys[0]] + to_pil(original_frame).save(output_dir / "original_frame.png", quality=100) + print("\nOriginal frame saved to:") + print(f" {output_dir / 'original_frame.png'}.") + + save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir) + save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir) + + +if __name__ == "__main__": + visualize_transforms() diff --git a/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors new file mode 100644 index 00000000..77699dab --- /dev/null +++ b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a +size 3686488 diff --git a/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors new file mode 100644 index 00000000..13f1033f --- /dev/null +++ b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f +size 40551392 diff --git a/tests/scripts/save_image_transforms_to_safetensors.py b/tests/scripts/save_image_transforms_to_safetensors.py new file mode 100644 index 00000000..9d024a01 --- /dev/null +++ b/tests/scripts/save_image_transforms_to_safetensors.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +import torch +from safetensors.torch import save_file + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms +from lerobot.common.utils.utils import init_hydra_config, seeded_context +from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID +from tests.utils import DEFAULT_CONFIG_PATH + + +def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path): + cfg = init_hydra_config(DEFAULT_CONFIG_PATH) + cfg_tf = cfg.training.image_transforms + default_tf = get_image_transforms( + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, + ) + + with seeded_context(1337): + img_tf = default_tf(original_frame) + + save_file({"default": img_tf}, output_dir / "default_transforms.safetensors") + + +def save_single_transforms(original_frame: torch.Tensor, output_dir: Path): + transforms = { + "brightness": [(0.5, 0.5), (2.0, 2.0)], + "contrast": [(0.5, 0.5), (2.0, 2.0)], + "saturation": [(0.5, 0.5), (2.0, 2.0)], + "hue": [(-0.25, -0.25), (0.25, 0.25)], + "sharpness": [(0.5, 0.5), (2.0, 2.0)], + } + + frames = {"original_frame": original_frame} + for transform, values in transforms.items(): + for min_max in values: + kwargs = { + f"{transform}_weight": 1.0, + f"{transform}_min_max": min_max, + } + tf = get_image_transforms(**kwargs) + key = f"{transform}_{min_max[0]}_{min_max[1]}" + frames[key] = tf(original_frame) + + save_file(frames, output_dir / "single_transforms.safetensors") + + +def main(): + dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None) + output_dir = Path(ARTIFACT_DIR) + output_dir.mkdir(parents=True, exist_ok=True) + original_frame = dataset[0][dataset.camera_keys[0]] + + save_single_transforms(original_frame, output_dir) + save_default_config_transform(original_frame, output_dir) + + +if __name__ == "__main__": + main() diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py new file mode 100644 index 00000000..ba6d972f --- /dev/null +++ b/tests/test_image_transforms.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from safetensors.torch import load_file +from torchvision.transforms import v2 +from torchvision.transforms.v2 import functional as F # noqa: N812 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms +from lerobot.common.utils.utils import init_hydra_config, seeded_context +from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel + +ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors") +DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp" + + +def load_png_to_tensor(path: Path): + return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1) + + +@pytest.fixture +def img(): + dataset = LeRobotDataset(DATASET_REPO_ID) + return dataset[0][dataset.camera_keys[0]] + + +@pytest.fixture +def img_random(): + return torch.rand(3, 480, 640) + + +@pytest.fixture +def color_jitters(): + return [ + v2.ColorJitter(brightness=0.5), + v2.ColorJitter(contrast=0.5), + v2.ColorJitter(saturation=0.5), + ] + + +@pytest.fixture +def single_transforms(): + return load_file(ARTIFACT_DIR / "single_transforms.safetensors") + + +@pytest.fixture +def default_transforms(): + return load_file(ARTIFACT_DIR / "default_transforms.safetensors") + + +def test_get_image_transforms_no_transform(img): + tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0) + torch.testing.assert_close(tf_actual(img), img) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_brightness(img, min_max): + tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max) + tf_expected = v2.ColorJitter(brightness=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_contrast(img, min_max): + tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max) + tf_expected = v2.ColorJitter(contrast=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_saturation(img, min_max): + tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max) + tf_expected = v2.ColorJitter(saturation=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)]) +def test_get_image_transforms_hue(img, min_max): + tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max) + tf_expected = v2.ColorJitter(hue=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_sharpness(img, min_max): + tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max) + tf_expected = SharpnessJitter(sharpness=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +def test_get_image_transforms_max_num_transforms(img): + tf_actual = get_image_transforms( + brightness_min_max=(0.5, 0.5), + contrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=False, + ) + tf_expected = v2.Compose( + [ + v2.ColorJitter(brightness=(0.5, 0.5)), + v2.ColorJitter(contrast=(0.5, 0.5)), + v2.ColorJitter(saturation=(0.5, 0.5)), + v2.ColorJitter(hue=(0.5, 0.5)), + SharpnessJitter(sharpness=(0.5, 0.5)), + ] + ) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@require_x86_64_kernel +def test_get_image_transforms_random_order(img): + out_imgs = [] + tf = get_image_transforms( + brightness_min_max=(0.5, 0.5), + contrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=True, + ) + with seeded_context(1337): + for _ in range(10): + out_imgs.append(tf(img)) + + for i in range(1, len(out_imgs)): + with pytest.raises(AssertionError): + torch.testing.assert_close(out_imgs[0], out_imgs[i]) + + +@pytest.mark.parametrize( + "transform, min_max_values", + [ + ("brightness", [(0.5, 0.5), (2.0, 2.0)]), + ("contrast", [(0.5, 0.5), (2.0, 2.0)]), + ("saturation", [(0.5, 0.5), (2.0, 2.0)]), + ("hue", [(-0.25, -0.25), (0.25, 0.25)]), + ("sharpness", [(0.5, 0.5), (2.0, 2.0)]), + ], +) +def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms): + for min_max in min_max_values: + kwargs = { + f"{transform}_weight": 1.0, + f"{transform}_min_max": min_max, + } + tf = get_image_transforms(**kwargs) + actual = tf(img) + key = f"{transform}_{min_max[0]}_{min_max[1]}" + expected = single_transforms[key] + torch.testing.assert_close(actual, expected) + + +@require_x86_64_kernel +def test_backward_compatibility_default_config(img, default_transforms): + cfg = init_hydra_config(DEFAULT_CONFIG_PATH) + cfg_tf = cfg.training.image_transforms + default_tf = get_image_transforms( + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, + ) + + with seeded_context(1337): + actual = default_tf(img) + + expected = default_transforms["default"] + + torch.testing.assert_close(actual, expected) + + +@pytest.mark.parametrize("p", [[0, 1], [1, 0]]) +def test_random_subset_apply_single_choice(p, img): + flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] + random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False) + actual = random_choice(img) + + p_horz, _ = p + if p_horz: + torch.testing.assert_close(actual, F.horizontal_flip(img)) + else: + torch.testing.assert_close(actual, F.vertical_flip(img)) + + +def test_random_subset_apply_random_order(img): + flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] + random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True) + # We can't really check whether the transforms are actually applied in random order. However, + # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform + # applies them in random order, we can use a fixed order to compute the expected value. + actual = random_order(img) + expected = v2.Compose(flips)(img) + torch.testing.assert_close(actual, expected) + + +def test_random_subset_apply_valid_transforms(color_jitters, img): + transform = RandomSubsetApply(color_jitters) + output = transform(img) + assert output.shape == img.shape + + +def test_random_subset_apply_probability_length_mismatch(color_jitters): + with pytest.raises(ValueError): + RandomSubsetApply(color_jitters, p=[0.5, 0.5]) + + +@pytest.mark.parametrize("n_subset", [0, 5]) +def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset): + with pytest.raises(ValueError): + RandomSubsetApply(color_jitters, n_subset=n_subset) + + +def test_sharpness_jitter_valid_range_tuple(img): + tf = SharpnessJitter((0.1, 2.0)) + output = tf(img) + assert output.shape == img.shape + + +def test_sharpness_jitter_valid_range_float(img): + tf = SharpnessJitter(0.5) + output = tf(img) + assert output.shape == img.shape + + +def test_sharpness_jitter_invalid_range_min_negative(): + with pytest.raises(ValueError): + SharpnessJitter((-0.1, 2.0)) + + +def test_sharpness_jitter_invalid_range_max_smaller(): + with pytest.raises(ValueError): + SharpnessJitter((2.0, 0.1)) diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py new file mode 100644 index 00000000..7ddbe7aa --- /dev/null +++ b/tests/test_push_dataset_to_hub.py @@ -0,0 +1,352 @@ +""" +This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API. +Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets, +we skip them for now in our CI. + +Example to run backward compatiblity tests locally: +``` +DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility +``` +""" + +from pathlib import Path + +import numpy as np +import pytest +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently +from lerobot.common.datasets.video_utils import encode_video_frames +from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub +from tests.utils import require_package_arg + + +def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3): + import zarr + + raw_dir.mkdir(parents=True, exist_ok=True) + zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr" + store = zarr.DirectoryStore(zarr_path) + zarr_data = zarr.group(store=store) + + zarr_data.create_dataset( + "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/img", + shape=(num_frames, 96, 96, 3), + chunks=(num_frames, 96, 96, 3), + dtype=np.uint8, + overwrite=True, + ) + zarr_data.create_dataset( + "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True + ) + + zarr_data["data/action"][:] = np.random.randn(num_frames, 1) + zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) + zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2) + zarr_data["data/state"][:] = np.random.randn(num_frames, 5) + zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2) + zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) + + store.close() + + +def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3): + import zarr + + raw_dir.mkdir(parents=True, exist_ok=True) + zarr_path = raw_dir / "cup_in_the_wild.zarr" + store = zarr.DirectoryStore(zarr_path) + zarr_data = zarr.group(store=store) + + zarr_data.create_dataset( + "data/camera0_rgb", + shape=(num_frames, 96, 96, 3), + chunks=(num_frames, 96, 96, 3), + dtype=np.uint8, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_demo_end_pose", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_demo_start_pose", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/robot0_eef_rot_axis_angle", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_gripper_width", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True + ) + + zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) + zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5) + zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) + + store.close() + + +def _mock_download_raw_xarm(raw_dir, num_frames=4): + import pickle + + dataset_dict = { + "observations": { + "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8), + "state": np.random.randn(num_frames, 4), + }, + "actions": np.random.randn(num_frames, 3), + "rewards": np.random.randn(num_frames), + "masks": np.random.randn(num_frames), + "dones": np.array([False, True, True, True]), + } + + raw_dir.mkdir(parents=True, exist_ok=True) + pkl_path = raw_dir / "buffer.pkl" + with open(pkl_path, "wb") as f: + pickle.dump(dataset_dict, f) + + +def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3): + import h5py + + for ep_idx in range(num_episodes): + raw_dir.mkdir(parents=True, exist_ok=True) + path_h5 = raw_dir / f"episode_{ep_idx}.hdf5" + with h5py.File(str(path_h5), "w") as f: + f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14)) + f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14)) + f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14)) + f.create_dataset( + "observations/images/top", + data=np.random.randint( + 0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8 + ), + ) + + +def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30): + from datetime import datetime, timedelta, timezone + + import pandas + + def write_parquet(key, timestamps, values): + data = { + "timestamp_utc": timestamps, + key: values, + } + df = pandas.DataFrame(data) + raw_dir.mkdir(parents=True, exist_ok=True) + df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow") + + episode_indices = [None, None, -1, None, None, -1, None, None, -1] + episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2] + frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1] + + cam_key = "observation.images.cam_high" + timestamps = [] + actions = [] + states = [] + frames = [] + # `+ num_episodes`` for buffer frames associated to episode_index=-1 + for i, frame_idx in enumerate(frame_indices): + t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps) + action = np.random.randn(21).tolist() + state = np.random.randn(21).tolist() + ep_idx = episode_indices_mapping[i] + frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}] + timestamps.append(t_utc) + actions.append(action) + states.append(state) + frames.append(frame) + + write_parquet(cam_key, timestamps, frames) + write_parquet("observation.state", timestamps, states) + write_parquet("action", timestamps, actions) + write_parquet("episode_index", timestamps, episode_indices) + + # write fake mp4 file for each episode + for ep_idx in range(num_episodes): + imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8) + + tmp_imgs_dir = raw_dir / "tmp_images" + save_images_concurrently(imgs_array, tmp_imgs_dir) + + fname = f"{cam_key}_episode_{ep_idx:06d}.mp4" + video_path = raw_dir / "videos" / fname + encode_video_frames(tmp_imgs_dir, video_path, fps) + + +def _mock_download_raw(raw_dir, repo_id): + if "wrist_gripper" in repo_id: + _mock_download_raw_dora(raw_dir) + elif "aloha" in repo_id: + _mock_download_raw_aloha(raw_dir) + elif "pusht" in repo_id: + _mock_download_raw_pusht(raw_dir) + elif "xarm" in repo_id: + _mock_download_raw_xarm(raw_dir) + elif "umi" in repo_id: + _mock_download_raw_umi(raw_dir) + else: + raise ValueError(repo_id) + + +def test_push_dataset_to_hub_invalid_repo_id(tmpdir): + with pytest.raises(ValueError): + push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id") + + +def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir): + tmpdir = Path(tmpdir) + out_dir = tmpdir / "out" + raw_dir = tmpdir / "raw" + # mkdir to skip download + raw_dir.mkdir(parents=True, exist_ok=True) + with pytest.raises(ValueError): + push_dataset_to_hub( + raw_dir=raw_dir, + raw_format="some_format", + repo_id="user/dataset", + local_dir=out_dir, + force_override=False, + ) + + +@pytest.mark.parametrize( + "required_packages, raw_format, repo_id", + [ + (["gym-pusht"], "pusht_zarr", "lerobot/pusht"), + (None, "xarm_pkl", "lerobot/xarm_lift_medium"), + (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), + (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"), + (None, "dora_parquet", "cadene/wrist_gripper"), + ], +) +@require_package_arg +def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id): + num_episodes = 3 + tmpdir = Path(tmpdir) + + raw_dir = tmpdir / f"{repo_id}_raw" + _mock_download_raw(raw_dir, repo_id) + + local_dir = tmpdir / repo_id + + lerobot_dataset = push_dataset_to_hub( + raw_dir=raw_dir, + raw_format=raw_format, + repo_id=repo_id, + push_to_hub=False, + local_dir=local_dir, + force_override=False, + cache_dir=tmpdir / "cache", + ) + + # minimal generic tests on the local directory containing LeRobotDataset + assert (local_dir / "meta_data" / "info.json").exists() + assert (local_dir / "meta_data" / "stats.safetensors").exists() + assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists() + for i in range(num_episodes): + for cam_key in lerobot_dataset.camera_keys: + assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists() + assert (local_dir / "train" / "dataset_info.json").exists() + assert (local_dir / "train" / "state.json").exists() + assert len(list((local_dir / "train").glob("*.arrow"))) > 0 + + # minimal generic tests on the item + item = lerobot_dataset[0] + assert "index" in item + assert "episode_index" in item + assert "timestamp" in item + for cam_key in lerobot_dataset.camera_keys: + assert cam_key in item + + +@pytest.mark.parametrize( + "raw_format, repo_id", + [ + # TODO(rcadene): add raw dataset test artifacts + ("pusht_zarr", "lerobot/pusht"), + ("xarm_pkl", "lerobot/xarm_lift_medium"), + ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), + ("umi_zarr", "lerobot/umi_cup_in_the_wild"), + ("dora_parquet", "cadene/wrist_gripper"), + ], +) +@pytest.mark.skip( + "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`" +) +def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id): + _, dataset_id = repo_id.split("/") + + tmpdir = Path(tmpdir) + raw_dir = tmpdir / f"{dataset_id}_raw" + local_dir = tmpdir / repo_id + + push_dataset_to_hub( + raw_dir=raw_dir, + raw_format=raw_format, + repo_id=repo_id, + push_to_hub=False, + local_dir=local_dir, + force_override=False, + cache_dir=tmpdir / "cache", + episodes=[0], + ) + + ds_actual = LeRobotDataset(repo_id, root=tmpdir) + ds_reference = LeRobotDataset(repo_id) + + assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset) + + def check_same_items(item1, item2): + assert item1.keys() == item2.keys(), "Keys mismatch" + + for key in item1: + if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor): + assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}" + else: + assert item1[key] == item2[key], f"Mismatch found in key: {key}" + + for i in range(len(ds_reference.hf_dataset)): + item_reference = ds_reference.hf_dataset[i] + item_actual = ds_actual.hf_dataset[i] + check_same_items(item_reference, item_actual) diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 71819568..33c5e80a 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + import pytest from lerobot.scripts.visualize_dataset import visualize_dataset @@ -30,3 +32,20 @@ def test_visualize_dataset(tmpdir, repo_id): serve=False, ) assert rrd_path.exists() + + +@pytest.mark.parametrize( + "repo_id", + ["lerobot/pusht"], +) +@pytest.mark.parametrize("root", [Path(__file__).parent / "data"]) +def test_visualize_local_dataset(tmpdir, repo_id, root): + rrd_path = visualize_dataset( + repo_id, + episode_index=0, + batch_size=32, + save=True, + output_dir=tmpdir, + root=root, + ) + assert rrd_path.exists() diff --git a/tests/utils.py b/tests/utils.py index ba49ee70..c1575656 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -76,6 +76,7 @@ def require_env(func): """ Decorator that skips the test if the required environment package is not installed. As it need 'env_name' in args, it also checks whether it is provided as an argument. + If 'env_name' is None, this check is skipped. """ @wraps(func) @@ -91,7 +92,7 @@ def require_env(func): # Perform the package check package_name = f"gym_{env_name}" - if not is_package_available(package_name): + if env_name is not None and not is_package_available(package_name): pytest.skip(f"gym-{env_name} not installed") return func(*args, **kwargs) @@ -99,6 +100,38 @@ def require_env(func): return wrapper +def require_package_arg(func): + """ + Decorator that skips the test if the required package is not installed. + This is similar to `require_env` but more general in that it can check any package (not just environments). + As it need 'required_packages' in args, it also checks whether it is provided as an argument. + If 'required_packages' is None, this check is skipped. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + # Determine if 'required_packages' is provided and extract its value + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + if "required_packages" in arg_names: + # Get the index of 'required_packages' and retrieve the value from args + index = arg_names.index("required_packages") + required_packages = args[index] if len(args) > index else kwargs.get("required_packages") + else: + raise ValueError("Function does not have 'required_packages' as an argument.") + + if required_packages is None: + return func(*args, **kwargs) + + # Perform the package check + for package in required_packages: + if not is_package_available(package): + pytest.skip(f"{package} not installed") + + return func(*args, **kwargs) + + return wrapper + + def require_package(package_name): """ Decorator that skips the test if the specified package is not installed.