added visualization for min and max transforms (#271)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a92d79fff2
commit
e28fa2344c
@@ -65,11 +65,10 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
|
||||
OUTPUT_DIR = Path("outputs/image_transforms")
|
||||
N_EXAMPLES = 5
|
||||
to_pil = ToPILImage()
|
||||
|
||||
|
||||
def save_config_all_transforms(cfg, original_frame, output_dir):
|
||||
def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
|
||||
tf = get_image_transforms(
|
||||
brightness_weight=cfg.brightness.weight,
|
||||
brightness_min_max=cfg.brightness.min_max,
|
||||
@@ -88,7 +87,7 @@ def save_config_all_transforms(cfg, original_frame, output_dir):
|
||||
output_dir_all = output_dir / "all"
|
||||
output_dir_all.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(1, N_EXAMPLES + 1):
|
||||
for i in range(1, n_examples + 1):
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
|
||||
|
||||
@@ -96,7 +95,7 @@ def save_config_all_transforms(cfg, original_frame, output_dir):
|
||||
print(f" {output_dir_all}")
|
||||
|
||||
|
||||
def save_config_single_transforms(cfg, original_frame, output_dir):
|
||||
def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
|
||||
transforms = [
|
||||
"brightness",
|
||||
"contrast",
|
||||
@@ -106,6 +105,7 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
|
||||
]
|
||||
print("Individual transforms examples saved to:")
|
||||
for transform in transforms:
|
||||
# Apply one transformation with random value in min_max range
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": cfg[f"{transform}"].min_max,
|
||||
@@ -114,18 +114,46 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
|
||||
output_dir_single = output_dir / f"{transform}"
|
||||
output_dir_single.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(1, N_EXAMPLES + 1):
|
||||
for i in range(1, n_examples + 1):
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
|
||||
|
||||
# Apply min transformation
|
||||
min_value, max_value = cfg[f"{transform}"].min_max
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": (min_value, min_value),
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100)
|
||||
|
||||
# Apply max transformation
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": (max_value, max_value),
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100)
|
||||
|
||||
# Apply mean transformation
|
||||
mean_value = (min_value + max_value) / 2
|
||||
kwargs = {
|
||||
f"{transform}_weight": cfg[f"{transform}"].weight,
|
||||
f"{transform}_min_max": (mean_value, mean_value),
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
transformed_frame = tf(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)
|
||||
|
||||
print(f" {output_dir_single}")
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def visualize_transforms(cfg):
|
||||
def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||
|
||||
output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
|
||||
output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get 1st frame from 1st camera of 1st episode
|
||||
@@ -134,8 +162,13 @@ def visualize_transforms(cfg):
|
||||
print("\nOriginal frame saved to:")
|
||||
print(f" {output_dir / 'original_frame.png'}.")
|
||||
|
||||
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
|
||||
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
|
||||
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def visualize_transforms_cli(cfg):
|
||||
visualize_transforms(cfg, output_dir=OUTPUT_DIR)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user