Implented visualize_image_transforms script

This commit is contained in:
Marina Barannikov
2024-06-05 13:30:32 +00:00
parent 4dbc1adb0d
commit 0fb3dd745b

View File

@@ -1,27 +1,30 @@
from pathlib import Path from pathlib import Path
import hydra
from torchvision.transforms import ToPILImage from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_transforms from lerobot.common.datasets.transforms import make_transforms
from lerobot.common.utils.utils import init_hydra_config
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
to_pil = ToPILImage() to_pil = ToPILImage()
def main(repo_id): def main(cfg, output_dir=Path("outputs/image_transforms")):
""" """
Apply a series of image transformations to a frame from a dataset and save the transformed images. Function to apply image transforms from a configuration and save the transformed images.
Args: Args:
repo_id (str): The ID of the repository. cfg (object): Configuration object containing the image transform settings and dataset_repo_id.
output_dir (str or Path, optional): Output directory to save the transformed images. Defaults to "outputs/image_transforms".
Returns:
None
""" """
transforms = ["colorjitter", "sharpness", "blur"] dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
dataset = LeRobotDataset(repo_id, transform=None) output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
output_dir = Path("outputs/image_transforms") / Path(repo_id)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Get first frame of 1st episode # Get first frame of 1st episode
@@ -33,25 +36,25 @@ def main(repo_id):
# Apply each single transformation # Apply each single transformation
for transform_name in transforms: for transform_name in transforms:
overrides = [ cfg.image_transform.enable=True
"image_transform.enable=True", cfg.image_transform.max_num_transforms=1
"image_transform.max_num_transforms=1",
]
for t in transforms: for t in transforms:
if t == transform_name: if t == transform_name:
overrides.append(f"image_transform.{t}.weight=1") cfg.image_transform[t].weight=1
overrides.append(f"image_transform.{t}_p=1")
else: else:
overrides.append(f"image_transform.{t}.weight=0") cfg.image_transform[t].weight=0
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=overrides,
)
transform = make_transforms(cfg.image_transform) transform = make_transforms(cfg.image_transform)
img = transform(frame) img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100) to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg: dict):
main(
cfg,
)
if __name__ == "__main__": if __name__ == "__main__":
repo_id = "cadene/reachy2_teleop_remi" visualize_transforms_cli()
main(repo_id)