From 85f1554da885212bfc6cd4ee9a43fdaf075782ae Mon Sep 17 00:00:00 2001
From: Cadene
Date: Fri, 19 Apr 2024 23:40:35 +0000
Subject: [PATCH] fix visualize_dataset

---
 lerobot/scripts/visualize_dataset.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 226fdc1fa..aad17a21f 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -22,11 +22,22 @@ def visualize_dataset_cli(cfg: dict):
 
 
 def cat_and_write_video(video_path, frames, fps):
-    # Expects images in [0, 255].
     frames = torch.cat(frames)
-    assert frames.dtype == torch.uint8
-    frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
-    imageio.mimsave(video_path, frames, fps=fps)
+
+    # Expects channel-first images in [0, 1]; frames is (b, c, h, w), so one frame is (c, h, w).
+    frame = frames[0]
+    c, h, w = frame.shape
+    assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
+
+    # sanity check that images are float32 in range [0,1]
+    assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
+    assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
+    assert frame.min() >= 0, f"expect pixels greater than 0, but instead {frame.min()=}"
+
+    # convert to channel last uint8 [0, 255]
+    frames = einops.rearrange(frames, "b c h w -> b h w c")
+    frames = (frames * 255).type(torch.uint8)
+    imageio.mimsave(video_path, frames.numpy(), fps=fps)
 
 
 def visualize_dataset(cfg: dict, out_dir=None):