Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
@@ -18,142 +18,102 @@
 This script will generate examples of transformed images as they are output by LeRobot dataset.
 Additionally, each individual transform can be visualized separately as well as examples of combined transforms
 
-
---- Usage Examples ---
-
-Increase hue jitter
-```
-python lerobot/scripts/visualize_image_transforms.py \
-    dataset_repo_id=lerobot/aloha_mobile_shrimp \
-    training.image_transforms.hue.min_max="[-0.25,0.25]"
-```
-
-Increase brightness & brightness weight
-```
-python lerobot/scripts/visualize_image_transforms.py \
-    dataset_repo_id=lerobot/aloha_mobile_shrimp \
-    training.image_transforms.brightness.weight=10.0 \
-    training.image_transforms.brightness.min_max="[1.0,2.0]"
-```
-
-Blur images and disable saturation & hue
-```
-python lerobot/scripts/visualize_image_transforms.py \
-    dataset_repo_id=lerobot/aloha_mobile_shrimp \
-    training.image_transforms.sharpness.weight=10.0 \
-    training.image_transforms.sharpness.min_max="[0.0,1.0]" \
-    training.image_transforms.saturation.weight=0.0 \
-    training.image_transforms.hue.weight=0.0
-```
-
-Use all transforms with random order
-```
-python lerobot/scripts/visualize_image_transforms.py \
-    dataset_repo_id=lerobot/aloha_mobile_shrimp \
-    training.image_transforms.max_num_transforms=5 \
-    training.image_transforms.random_order=true
-```
+Example:
+```bash
+python lerobot/scripts/visualize_image_transforms.py \
+    --repo_id=lerobot/pusht \
+    --episodes='[0]' \
+    --image_transforms.enable=True
+```
 """
 
 import logging
+from copy import deepcopy
+from dataclasses import replace
 from pathlib import Path
 
-import hydra
+import draccus
 from torchvision.transforms import ToPILImage
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.transforms import get_image_transforms
+from lerobot.common.datasets.transforms import (
+    ImageTransforms,
+    ImageTransformsConfig,
+    make_transform_from_config,
+)
+from lerobot.configs.default import DatasetConfig
 
 OUTPUT_DIR = Path("outputs/image_transforms")
 to_pil = ToPILImage()
 
 
-def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
-    tf = get_image_transforms(
-        brightness_weight=cfg.brightness.weight,
-        brightness_min_max=cfg.brightness.min_max,
-        contrast_weight=cfg.contrast.weight,
-        contrast_min_max=cfg.contrast.min_max,
-        saturation_weight=cfg.saturation.weight,
-        saturation_min_max=cfg.saturation.min_max,
-        hue_weight=cfg.hue.weight,
-        hue_min_max=cfg.hue.min_max,
-        sharpness_weight=cfg.sharpness.weight,
-        sharpness_min_max=cfg.sharpness.min_max,
-        max_num_transforms=cfg.max_num_transforms,
-        random_order=cfg.random_order,
-    )
-
+def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
     output_dir_all = output_dir / "all"
     output_dir_all.mkdir(parents=True, exist_ok=True)
 
+    tfs = ImageTransforms(cfg)
     for i in range(1, n_examples + 1):
-        transformed_frame = tf(original_frame)
+        transformed_frame = tfs(original_frame)
         to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
 
     print("Combined transforms examples saved to:")
     print(f"    {output_dir_all}")
 
 
-def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
-    transforms = [
-        "brightness",
-        "contrast",
-        "saturation",
-        "hue",
-        "sharpness",
-    ]
+def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
+    if not cfg.enable:
+        logging.warning(
+            "No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
+        )
+        return
+
     print("Individual transforms examples saved to:")
-    for transform in transforms:
-        # Apply one transformation with random value in min_max range
-        kwargs = {
-            f"{transform}_weight": cfg[f"{transform}"].weight,
-            f"{transform}_min_max": cfg[f"{transform}"].min_max,
-        }
-        tf = get_image_transforms(**kwargs)
-        output_dir_single = output_dir / f"{transform}"
+    for tf_name, tf_cfg in cfg.tfs.items():
+        # Apply a few transformation with random value in min_max range
+        output_dir_single = output_dir / tf_name
         output_dir_single.mkdir(parents=True, exist_ok=True)
 
+        tf = make_transform_from_config(tf_cfg)
         for i in range(1, n_examples + 1):
             transformed_frame = tf(original_frame)
             to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
 
-        # Apply min transformation
-        min_value, max_value = cfg[f"{transform}"].min_max
-        kwargs = {
-            f"{transform}_weight": cfg[f"{transform}"].weight,
-            f"{transform}_min_max": (min_value, min_value),
-        }
-        tf = get_image_transforms(**kwargs)
-        transformed_frame = tf(original_frame)
-        to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100)
-
-        # Apply max transformation
-        kwargs = {
-            f"{transform}_weight": cfg[f"{transform}"].weight,
-            f"{transform}_min_max": (max_value, max_value),
-        }
-        tf = get_image_transforms(**kwargs)
-        transformed_frame = tf(original_frame)
-        to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100)
-
-        # Apply mean transformation
-        mean_value = (min_value + max_value) / 2
-        kwargs = {
-            f"{transform}_weight": cfg[f"{transform}"].weight,
-            f"{transform}_min_max": (mean_value, mean_value),
-        }
-        tf = get_image_transforms(**kwargs)
-        transformed_frame = tf(original_frame)
-        to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)
+        # Apply min, max, average transformations
+        tf_cfg_kwgs_min = deepcopy(tf_cfg.kwargs)
+        tf_cfg_kwgs_max = deepcopy(tf_cfg.kwargs)
+        tf_cfg_kwgs_avg = deepcopy(tf_cfg.kwargs)
+
+        for key, (min_, max_) in tf_cfg.kwargs.items():
+            avg = (min_ + max_) / 2
+            tf_cfg_kwgs_min[key] = [min_, min_]
+            tf_cfg_kwgs_max[key] = [max_, max_]
+            tf_cfg_kwgs_avg[key] = [avg, avg]
+
+        tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
+        tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
+        tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
+
+        tf_frame_min = tf_min(original_frame)
+        tf_frame_max = tf_max(original_frame)
+        tf_frame_avg = tf_avg(original_frame)
+
+        to_pil(tf_frame_min).save(output_dir_single / "min.png", quality=100)
+        to_pil(tf_frame_max).save(output_dir_single / "max.png", quality=100)
+        to_pil(tf_frame_avg).save(output_dir_single / "mean.png", quality=100)
 
         print(f"    {output_dir_single}")
 
 
-def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
-    dataset = LeRobotDataset(cfg.dataset_repo_id)
+@draccus.wrap()
+def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
+    dataset = LeRobotDataset(
+        repo_id=cfg.repo_id,
+        episodes=cfg.episodes,
+        local_files_only=cfg.local_files_only,
+        video_backend=cfg.video_backend,
+    )
 
-    output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
+    output_dir = output_dir / cfg.repo_id.split("/")[-1]
     output_dir.mkdir(parents=True, exist_ok=True)
 
     # Get 1st frame from 1st camera of 1st episode
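The hunk above swaps Hydra-style overrides (`training.image_transforms.hue.min_max="[-0.25,0.25]"`) for draccus flags parsed directly into nested dataclasses (`--image_transforms.enable=True`). A minimal sketch of that pattern, assuming draccus's `wrap()` decorator parses the annotated dataclass from the command line as the new entry point relies on; `DemoConfig` and `DemoTransformsConfig` are hypothetical stand-ins, not LeRobot's actual `DatasetConfig` and `ImageTransformsConfig`:

```python
from dataclasses import dataclass, field

import draccus


@dataclass
class DemoTransformsConfig:
    # Hypothetical subset of the image-transform options seen in the diff.
    enable: bool = False
    max_num_transforms: int = 3
    random_order: bool = False


@dataclass
class DemoConfig:
    # Nested dataclass fields become dotted CLI flags, e.g. --image_transforms.enable=True
    repo_id: str = "lerobot/pusht"
    image_transforms: DemoTransformsConfig = field(default_factory=DemoTransformsConfig)


@draccus.wrap()
def main(cfg: DemoConfig):
    # Invoked as e.g.:
    #   python demo.py --repo_id=lerobot/pusht --image_transforms.enable=True
    print(cfg.repo_id, cfg.image_transforms.enable)


if __name__ == "__main__":
    main()
```

This is why the new docstring example can toggle `--image_transforms.enable=True` without any Hydra config group or YAML override.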
@@ -162,14 +122,9 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
     print("\nOriginal frame saved to:")
     print(f"    {output_dir / 'original_frame.png'}.")
 
-    save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
-    save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
-
-
-@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
-def visualize_transforms_cli(cfg):
-    visualize_transforms(cfg, output_dir=OUTPUT_DIR)
+    save_all_transforms(cfg.image_transforms, original_frame, output_dir, n_examples)
+    save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
 
 
 if __name__ == "__main__":
-    visualize_transforms_cli()
+    visualize_image_transforms()
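For reference, the min/max/mean sweep that the new `save_each_transform` performs on each transform's `kwargs` can be isolated as below. `DemoTransformConfig` is a hypothetical stand-in for LeRobot's per-transform config, kept only to show the `deepcopy` + `dataclasses.replace` pattern from the diff:

```python
from copy import deepcopy
from dataclasses import dataclass, field, replace


@dataclass
class DemoTransformConfig:
    # Stand-in config: one weighted transform with (min, max) ranges per kwarg.
    weight: float = 1.0
    type: str = "ColorJitter"
    kwargs: dict = field(default_factory=lambda: {"brightness": (0.8, 1.2)})


def fixed_value_variants(tf_cfg: DemoTransformConfig):
    """Return copies of the config pinned to the min, max, and mean of each range."""
    kwgs_min, kwgs_max, kwgs_avg = (deepcopy(tf_cfg.kwargs) for _ in range(3))
    for key, (min_, max_) in tf_cfg.kwargs.items():
        avg = (min_ + max_) / 2
        kwgs_min[key] = [min_, min_]
        kwgs_max[key] = [max_, max_]
        kwgs_avg[key] = [avg, avg]
    return (
        replace(tf_cfg, kwargs=kwgs_min),
        replace(tf_cfg, kwargs=kwgs_max),
        replace(tf_cfg, kwargs=kwgs_avg),
    )


if __name__ == "__main__":
    for variant in fixed_value_variants(DemoTransformConfig()):
        print(variant.kwargs)  # {'brightness': [0.8, 0.8]}, then [1.2, 1.2], then [1.0, 1.0]
```

Pinning both ends of each range to the same value makes the transform deterministic, which is what lets the script render `min.png`, `max.png`, and `mean.png` as the extremes and midpoint of the configured jitter.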