Add video decoding to LeRobotDataset (#92)

This commit is contained in:
Remi
2024-05-03 00:50:19 +02:00
committed by GitHub
parent c1668924ab
commit b2cda12f87
116 changed files with 1406 additions and 301 deletions

View File

@@ -1,14 +1,10 @@
import json
from copy import deepcopy
from math import ceil
from pathlib import Path
import datasets
import einops
import torch
import tqdm
from datasets import Image, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
@@ -57,6 +53,9 @@ def hf_transform_to_torch(items_dict):
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
@@ -127,17 +126,29 @@ def load_info(repo_id, version, root) -> dict:
return info
def load_videos(repo_id, version, root) -> Path:
if root is not None:
path = Path(root) / repo_id / "videos"
else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version)
path = Path(repo_dir) / "videos"
return path
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tol: float,
tolerance_s: float,
) -> dict[torch.Tensor]:
"""
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
@@ -156,7 +167,7 @@ def load_previous_and_future_frames(
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
@@ -194,11 +205,11 @@ def load_previous_and_future_frames(
# TODO(rcadene): synchronize timestamps + interpolation if needed
is_pad = min_ > tol
is_pad = min_ > tolerance_s
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
@@ -207,144 +218,18 @@ def load_previous_and_future_frames(
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
item[key] = torch.stack(item[key])
if isinstance(item[key][0], dict) and "path" in item[key][0]:
# video mode where frame are expressed as dict of path and timestamp
item[key] = item[key]
else:
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images of `hf_dataset` are in channel first format
"""
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=0,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in hf_dataset.features.items():
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, Image):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(hf_dataset)
stats_patterns = get_stats_einops_patterns(hf_dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
for key in stats_patterns:
mean[key] = torch.tensor(0.0).float()
std[key] = torch.tensor(0.0).float()
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(hf_dataset, batch_size, seed):
generator = torch.Generator()
generator.manual_seed(seed)
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
drop_last=False,
generator=generator,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation.
batch_mean = einops.reduce(batch[key], pattern, "mean")
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
if i == ceil(max_num_samples / batch_size) - 1:
break
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
if i == ceil(max_num_samples / batch_size) - 1:
break
for key in stats_patterns:
std[key] = torch.sqrt(std[key])
stats = {}
for key in stats_patterns:
stats[key] = {
"mean": mean[key],
"std": std[key],
"max": max[key],
"min": min[key],
}
return stats
def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.