This commit is contained in:
Cadene
2024-03-19 13:41:49 +00:00
parent a346469a5a
commit 9cdc24bc0e
3 changed files with 152 additions and 104 deletions

51
test.py
View File

@@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12
def yuv_to_rgb(frames):
assert frames.dtype == torch.uint8
assert frames.ndim == 4
assert frames.shape[1] == 3
frames = frames.cpu().to(torch.float)
y = frames[..., 0, :, :]
u = frames[..., 1, :, :]
v = frames[..., 2, :, :]
y /= 255
u = u / 255 - 0.5
v = v / 255 - 0.5
r = y + 1.13983 * v
g = y + -0.39465 * u - 0.58060 * v
b = y + 2.03211 * u
rgb = torch.stack([r, g, b], 1)
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
return rgb
def yuv_to_rgb_cv2(frames, return_hwc=True):
assert frames.dtype == torch.uint8
assert frames.ndim == 4
assert frames.shape[1] == 3
frames = frames.cpu()
import cv2
frames = einops.rearrange(frames, "b c h w -> b h w c")
frames = frames.numpy()
frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
frames = [torch.from_numpy(frame) for frame in frames]
frames = torch.stack(frames)
if not return_hwc:
frames = einops.rearrange(frames, "b h w c -> b c h w")
return frames
def count_frames(video_path):
try:
# Construct the ffprobe command to get the number of frames
@@ -272,7 +232,7 @@ if __name__ == "__main__":
if "cuvid" in k:
print(f" - {k}")
def create_replay_buffer(device):
def create_replay_buffer(device, format=None):
data_dir = Path("tmp/2024_03_17_data_video/pusht")
num_slices = 1
@@ -293,6 +253,7 @@ if __name__ == "__main__":
data_dir=data_dir,
device=device,
frame_rate=None,
format=format,
in_keys=[("observation", "frame")],
out_keys=[("observation", "frame", "data")],
),
@@ -324,8 +285,8 @@ if __name__ == "__main__":
print(time.monotonic() - start)
def test_plot(seed=1337):
rb_cuda = create_replay_buffer(device="cuda")
rb_cpu = create_replay_buffer(device="cuda")
rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
n_rows = 2 # len(replay_buffer)
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
@@ -337,7 +298,7 @@ if __name__ == "__main__":
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
frames = batch_cpu["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True)
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1
axes[i][0].imshow(frames[0])
@@ -348,7 +309,7 @@ if __name__ == "__main__":
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
frames = batch_cuda["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True)
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1
axes[i][1].imshow(frames[0])