Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
@@ -19,29 +19,21 @@ import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
from lerobot.common.datasets.transforms import (
|
||||
ImageTransformConfig,
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
with seeded_context(1337):
|
||||
img_tf = default_tf(original_frame)
|
||||
@@ -51,29 +43,26 @@ def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path
|
||||
|
||||
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
||||
transforms = {
|
||||
"brightness": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"contrast": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"saturation": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"hue": [(-0.25, -0.25), (0.25, 0.25)],
|
||||
"sharpness": [(0.5, 0.5), (2.0, 2.0)],
|
||||
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
}
|
||||
|
||||
frames = {"original_frame": original_frame}
|
||||
for transform, values in transforms.items():
|
||||
for min_max in values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||
for tf_type, tf_name, min_max_values in transforms.items():
|
||||
for min_max in min_max_values:
|
||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
|
||||
frames[key] = tf(original_frame)
|
||||
|
||||
save_file(frames, output_dir / "single_transforms.safetensors")
|
||||
|
||||
|
||||
def main():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, episodes=[0], image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
|
||||
Reference in New Issue
Block a user