diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 226fdc1f..aad17a21 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -22,11 +22,22 @@ def visualize_dataset_cli(cfg: dict): def cat_and_write_video(video_path, frames, fps): - # Expects images in [0, 255]. frames = torch.cat(frames) - assert frames.dtype == torch.uint8 - frames = einops.rearrange(frames, "b c h w -> b h w c").numpy() - imageio.mimsave(video_path, frames, fps=fps) + + # Expects images in [0, 1]. + frame = frames[0] + _, c, h, w = frame.shape + assert c < h and c < w, f"expect channel first images, but instead {frame.shape}" + + # sanity check that images are float32 in range [0,1] + assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}" + assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}" + assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}" + + # convert to channel last uint8 [0, 255] + frames = einops.rearrange(frames, "b c h w -> b h w c") + frames = (frames * 255).type(torch.uint8) + imageio.mimsave(video_path, frames.numpy(), fps=fps) def visualize_dataset(cfg: dict, out_dir=None):