fix visualize_dataset
This commit is contained in:
@@ -22,11 +22,22 @@ def visualize_dataset_cli(cfg: dict):
|
|||||||
|
|
||||||
|
|
||||||
def cat_and_write_video(video_path, frames, fps):
|
def cat_and_write_video(video_path, frames, fps):
|
||||||
# Expects images in [0, 255].
|
|
||||||
frames = torch.cat(frames)
|
frames = torch.cat(frames)
|
||||||
assert frames.dtype == torch.uint8
|
|
||||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
# Expects images in [0, 1].
|
||||||
imageio.mimsave(video_path, frames, fps=fps)
|
frame = frames[0]
|
||||||
|
_, c, h, w = frame.shape
|
||||||
|
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
|
||||||
|
|
||||||
|
# sanity check that images are float32 in range [0,1]
|
||||||
|
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
|
||||||
|
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
|
||||||
|
assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
|
||||||
|
|
||||||
|
# convert to channel last uint8 [0, 255]
|
||||||
|
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||||
|
frames = (frames * 255).type(torch.uint8)
|
||||||
|
imageio.mimsave(video_path, frames.numpy(), fps=fps)
|
||||||
|
|
||||||
|
|
||||||
def visualize_dataset(cfg: dict, out_dir=None):
|
def visualize_dataset(cfg: dict, out_dir=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user