Add video decoding to LeRobotDataset (#92)

This commit is contained in:
Remi
2024-05-03 00:50:19 +02:00
committed by GitHub
parent c1668924ab
commit b2cda12f87
116 changed files with 1406 additions and 301 deletions

View File

@@ -15,10 +15,12 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
)
from lerobot.common.datasets.utils import (
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
compute_stats,
flatten_dict,
get_stats_einops_patterns,
)
from lerobot.common.datasets.utils import (
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
@@ -105,15 +107,15 @@ def test_compute_stats_on_xarm():
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=8,
num_workers=0,
batch_size=len(dataset),
shuffle=False,
)