forked from tangger/lerobot
Implented visualize_image_transforms script
This commit is contained in:
@@ -1,27 +1,30 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
|
||||||
from torchvision.transforms import ToPILImage
|
from torchvision.transforms import ToPILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.transforms import make_transforms
|
from lerobot.common.datasets.transforms import make_transforms
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
|
||||||
|
|
||||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
|
||||||
to_pil = ToPILImage()
|
to_pil = ToPILImage()
|
||||||
|
|
||||||
|
|
||||||
def main(repo_id):
|
def main(cfg, output_dir=Path("outputs/image_transforms")):
|
||||||
"""
|
"""
|
||||||
Apply a series of image transformations to a frame from a dataset and save the transformed images.
|
Function to apply image transforms from a configuration and save the transformed images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id (str): The ID of the repository.
|
cfg (object): Configuration object containing the image transform settings and dataset_repo_id.
|
||||||
|
output_dir (str or Path, optional): Output directory to save the transformed images. Defaults to "outputs/image_transforms".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
transforms = ["colorjitter", "sharpness", "blur"]
|
dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
|
||||||
|
|
||||||
dataset = LeRobotDataset(repo_id, transform=None)
|
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
|
||||||
output_dir = Path("outputs/image_transforms") / Path(repo_id)
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Get first frame of 1st episode
|
# Get first frame of 1st episode
|
||||||
@@ -33,25 +36,25 @@ def main(repo_id):
|
|||||||
|
|
||||||
# Apply each single transformation
|
# Apply each single transformation
|
||||||
for transform_name in transforms:
|
for transform_name in transforms:
|
||||||
overrides = [
|
cfg.image_transform.enable=True
|
||||||
"image_transform.enable=True",
|
cfg.image_transform.max_num_transforms=1
|
||||||
"image_transform.max_num_transforms=1",
|
|
||||||
]
|
|
||||||
for t in transforms:
|
for t in transforms:
|
||||||
if t == transform_name:
|
if t == transform_name:
|
||||||
overrides.append(f"image_transform.{t}.weight=1")
|
cfg.image_transform[t].weight=1
|
||||||
overrides.append(f"image_transform.{t}_p=1")
|
|
||||||
else:
|
else:
|
||||||
overrides.append(f"image_transform.{t}.weight=0")
|
cfg.image_transform[t].weight=0
|
||||||
cfg = init_hydra_config(
|
|
||||||
DEFAULT_CONFIG_PATH,
|
|
||||||
overrides=overrides,
|
|
||||||
)
|
|
||||||
transform = make_transforms(cfg.image_transform)
|
transform = make_transforms(cfg.image_transform)
|
||||||
img = transform(frame)
|
img = transform(frame)
|
||||||
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
|
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
|
def visualize_transforms_cli(cfg: dict):
|
||||||
|
main(
|
||||||
|
cfg,
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
repo_id = "cadene/reachy2_teleop_remi"
|
visualize_transforms_cli()
|
||||||
main(repo_id)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user