This commit is contained in:
Simon Alibert
2024-06-06 15:23:49 +00:00
parent bdc0ebd36a
commit a86f387554
3 changed files with 102 additions and 8 deletions

View File

@@ -1,6 +1,8 @@
from pathlib import Path
from torchvision.transforms import ToPILImage, v2
import torch
from torchvision.transforms import v2
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RangeRandomSharpness
@@ -9,7 +11,7 @@ from lerobot.common.utils.utils import seeded_context
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
ARTIFACT_DIR = "tests/data/save_image_transforms"
SEED = 1336
to_pil = ToPILImage()
to_pil = v2.ToPILImage()
def main(repo_id):
@@ -30,12 +32,15 @@ def main(repo_id):
"sharpness": RangeRandomSharpness(0.0, 2.0),
}
# Apply each single transformation
# frames = {"original_frame": original_frame}
for name, transform in transforms.items():
with seeded_context(SEED):
# transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
transformed_frame = transform(original_frame)
# frames[name] = transform(original_frame)
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
# save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")
if __name__ == "__main__":
repo_id = "lerobot/aloha_mobile_shrimp"