WIP
This commit is contained in:
51
test.py
51
test.py
@@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12
|
||||
NUM_ACTION_CHANNELS = 12
|
||||
|
||||
|
||||
def yuv_to_rgb(frames):
|
||||
assert frames.dtype == torch.uint8
|
||||
assert frames.ndim == 4
|
||||
assert frames.shape[1] == 3
|
||||
|
||||
frames = frames.cpu().to(torch.float)
|
||||
y = frames[..., 0, :, :]
|
||||
u = frames[..., 1, :, :]
|
||||
v = frames[..., 2, :, :]
|
||||
|
||||
y /= 255
|
||||
u = u / 255 - 0.5
|
||||
v = v / 255 - 0.5
|
||||
|
||||
r = y + 1.13983 * v
|
||||
g = y + -0.39465 * u - 0.58060 * v
|
||||
b = y + 2.03211 * u
|
||||
|
||||
rgb = torch.stack([r, g, b], 1)
|
||||
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
|
||||
return rgb
|
||||
|
||||
|
||||
def yuv_to_rgb_cv2(frames, return_hwc=True):
|
||||
assert frames.dtype == torch.uint8
|
||||
assert frames.ndim == 4
|
||||
assert frames.shape[1] == 3
|
||||
frames = frames.cpu()
|
||||
import cv2
|
||||
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||
frames = frames.numpy()
|
||||
frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
|
||||
frames = [torch.from_numpy(frame) for frame in frames]
|
||||
frames = torch.stack(frames)
|
||||
if not return_hwc:
|
||||
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||
return frames
|
||||
|
||||
|
||||
def count_frames(video_path):
|
||||
try:
|
||||
# Construct the ffprobe command to get the number of frames
|
||||
@@ -272,7 +232,7 @@ if __name__ == "__main__":
|
||||
if "cuvid" in k:
|
||||
print(f" - {k}")
|
||||
|
||||
def create_replay_buffer(device):
|
||||
def create_replay_buffer(device, format=None):
|
||||
data_dir = Path("tmp/2024_03_17_data_video/pusht")
|
||||
|
||||
num_slices = 1
|
||||
@@ -293,6 +253,7 @@ if __name__ == "__main__":
|
||||
data_dir=data_dir,
|
||||
device=device,
|
||||
frame_rate=None,
|
||||
format=format,
|
||||
in_keys=[("observation", "frame")],
|
||||
out_keys=[("observation", "frame", "data")],
|
||||
),
|
||||
@@ -324,8 +285,8 @@ if __name__ == "__main__":
|
||||
print(time.monotonic() - start)
|
||||
|
||||
def test_plot(seed=1337):
|
||||
rb_cuda = create_replay_buffer(device="cuda")
|
||||
rb_cpu = create_replay_buffer(device="cuda")
|
||||
rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
|
||||
rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
|
||||
|
||||
n_rows = 2 # len(replay_buffer)
|
||||
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
|
||||
@@ -337,7 +298,7 @@ if __name__ == "__main__":
|
||||
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
|
||||
frames = batch_cpu["observation", "frame", "data"]
|
||||
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
||||
frames = yuv_to_rgb(frames, return_hwc=True)
|
||||
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
||||
assert frames.shape[0] == 1
|
||||
axes[i][0].imshow(frames[0])
|
||||
|
||||
@@ -348,7 +309,7 @@ if __name__ == "__main__":
|
||||
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
|
||||
frames = batch_cuda["observation", "frame", "data"]
|
||||
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
||||
frames = yuv_to_rgb(frames, return_hwc=True)
|
||||
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
||||
assert frames.shape[0] == 1
|
||||
axes[i][1].imshow(frames[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user