diff --git a/lerobot/scripts/show_image_transforms.py b/lerobot/scripts/show_image_transforms.py index cca618348..e0db3377f 100644 --- a/lerobot/scripts/show_image_transforms.py +++ b/lerobot/scripts/show_image_transforms.py @@ -1,27 +1,30 @@ from pathlib import Path +import hydra + from torchvision.transforms import ToPILImage from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.transforms import make_transforms -from lerobot.common.utils.utils import init_hydra_config -DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" to_pil = ToPILImage() -def main(repo_id): +def main(cfg, output_dir=Path("outputs/image_transforms")): """ - Apply a series of image transformations to a frame from a dataset and save the transformed images. + Function to apply image transforms from a configuration and save the transformed images. Args: - repo_id (str): The ID of the repository. + cfg (object): Configuration object containing the image transform settings and dataset_repo_id. + output_dir (str or Path, optional): Output directory to save the transformed images. Defaults to "outputs/image_transforms". + + Returns: + None """ - - transforms = ["colorjitter", "sharpness", "blur"] - - dataset = LeRobotDataset(repo_id, transform=None) - output_dir = Path("outputs/image_transforms") / Path(repo_id) + + dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None) + + output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1]) output_dir.mkdir(parents=True, exist_ok=True) # Get first frame of 1st episode @@ -33,25 +36,25 @@ def main(repo_id): # Apply each single transformation for transform_name in transforms: - overrides = [ - "image_transform.enable=True", - "image_transform.max_num_transforms=1", - ] + cfg.image_transform.enable=True + cfg.image_transform.max_num_transforms=1 + for t in transforms: if t == transform_name: - overrides.append(f"image_transform.{t}.weight=1") - overrides.append(f"image_transform.{t}_p=1") + cfg.image_transform[t].weight=1 else: - overrides.append(f"image_transform.{t}.weight=0") - cfg = init_hydra_config( - DEFAULT_CONFIG_PATH, - overrides=overrides, - ) + cfg.image_transform[t].weight=0 + transform = make_transforms(cfg.image_transform) img = transform(frame) to_pil(img).save(output_dir / f"{transform_name}.png", quality=100) +@hydra.main(version_base="1.2", config_name="default", config_path="../configs") +def visualize_transforms_cli(cfg: dict): + main( + cfg, + ) + if __name__ == "__main__": - repo_id = "cadene/reachy2_teleop_remi" - main(repo_id) + visualize_transforms_cli()