fix visualize_dataset
This commit is contained in:
@@ -22,11 +22,22 @@ def visualize_dataset_cli(cfg: dict):
|
||||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
# Expects images in [0, 255].
|
||||
frames = torch.cat(frames)
|
||||
assert frames.dtype == torch.uint8
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||
imageio.mimsave(video_path, frames, fps=fps)
|
||||
|
||||
# Expects images in [0, 1].
|
||||
frame = frames[0]
|
||||
_, c, h, w = frame.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
|
||||
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
|
||||
assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
|
||||
|
||||
# convert to channel last uint8 [0, 255]
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||
frames = (frames * 255).type(torch.uint8)
|
||||
imageio.mimsave(video_path, frames.numpy(), fps=fps)
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
|
||||
Reference in New Issue
Block a user