Add video decoding to LeRobotDataset (#92)

This commit is contained in:
Remi
2024-05-03 00:50:19 +02:00
committed by GitHub
parent c1668924ab
commit b2cda12f87
116 changed files with 1406 additions and 301 deletions

View File

@@ -16,9 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir) -> bool:
@@ -79,14 +77,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
video_path = out_dir / "videos" / f"{img_key}_episode_{ep_idx:06d}.mp4"
encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the episode idx
ep_dict[img_key] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -122,7 +123,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
image_keys = [key for key in data_dict if "observation.images." in key]
for image_key in image_keys:
if video:
features[image_key] = Value(dtype="int64", id="video")
features[image_key] = VideoFrame()
else:
features[image_key] = Image()

View File

@@ -0,0 +1,146 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are in channel first format
"""
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in dataset.features.items():
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, (VideoFrame, Image)):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
def compute_stats(
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
):
if max_num_samples is None:
max_num_samples = len(dataset)
# for more info on why we need to set the same number of workers, see `load_from_videos`
stats_patterns = get_stats_einops_patterns(dataset, num_workers)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
for key in stats_patterns:
mean[key] = torch.tensor(0.0).float()
std[key] = torch.tensor(0.0).float()
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(dataset, batch_size, seed):
generator = torch.Generator()
generator.manual_seed(seed)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=False,
generator=generator,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation.
batch_mean = einops.reduce(batch[key], pattern, "mean")
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
if i == ceil(max_num_samples / batch_size) - 1:
break
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
if i == ceil(max_num_samples / batch_size) - 1:
break
for key in stats_patterns:
std[key] = torch.sqrt(std[key])
stats = {}
for key in stats_patterns:
stats[key] = {
"mean": mean[key],
"std": std[key],
"max": max[key],
"min": min[key],
}
return stats

View File

@@ -14,9 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir):
@@ -127,26 +125,28 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict = {}
imgs_array = [x.numpy() for x in image]
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the episode index
ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
# store the reference to the video frame
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = agent_pos
ep_dict["action"] = actions[id_from:id_to]
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = image[1:],
@@ -174,7 +174,7 @@ def to_hf_dataset(data_dict, video):
features = {}
if video:
features["observation.image"] = Value(dtype="int64", id="video")
features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()

View File

@@ -16,9 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir) -> bool:
@@ -103,25 +101,27 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
# load 57MB of images in RAM (400x224x224x3 uint8)
imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the episode index
ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
# store the reference to the video frame
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
@@ -151,7 +151,7 @@ def to_hf_dataset(data_dict, video):
features = {}
if video:
features["observation.image"] = Value(dtype="int64", id="video")
features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()

View File

@@ -14,9 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
# TODO(rcadene): enable for PR video dataset
# from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir):
@@ -76,26 +74,28 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
ep_dict = {}
imgs_array = [x.numpy() for x in image]
img_key = "observation.image"
if video:
# save png images in temporary directory
tmp_imgs_dir = out_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
encode_video_frames(tmp_imgs_dir, video_path, fps) # noqa: F821
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = out_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the episode index
ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
# store the reference to the video frame
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = next_image
@@ -122,7 +122,7 @@ def to_hf_dataset(data_dict, video):
features = {}
if video:
features["observation.image"] = Value(dtype="int64", id="video")
features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()