Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3

Remi Cadene
2025-05-16 17:50:14 +00:00
7 changed files with 97 additions and 33 deletions

View File

@@ -18,6 +18,7 @@ from lerobot.common.datasets.utils import (
     concat_video_files,
     get_parquet_file_size_in_mb,
     get_video_size_in_mb,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -100,7 +101,8 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
         ]
     )
     fps, robot_type, features = validate_all_metadata(all_metadata)
-    video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
+    video_keys = [key for key in features if features[key]["dtype"] == "video"]
+    image_keys = [key for key in features if features[key]["dtype"] == "image"]

     # Initialize output dataset metadata
     dst_meta = LeRobotDatasetMetadata.create(
@@ -203,9 +205,21 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
             file_idx,
         )

-        # Update the video index tracking
-        video_idx["chunk_idx"] = chunk_idx
-        video_idx["file_idx"] = file_idx
+        if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
+            # Size limit is reached, prepare new parquet file
+            aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
+                aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
+            )
+            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
+                chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
+            )
+            aggr_path.parent.mkdir(parents=True, exist_ok=True)
+            df.to_parquet(aggr_path)
+        else:
+            # Update the existing parquet file with new rows
+            aggr_df = pd.read_parquet(aggr_path)
+            df = pd.concat([aggr_df, df], ignore_index=True)
+            safe_write_dataframe_to_parquet(df, aggr_path, image_keys)

     return videos_idx
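For orientation, the hunk above introduces an append-or-rollover policy when aggregating rows into the destination dataset: new rows are appended to the current parquet file until a size threshold is reached, at which point the chunk/file indices advance and a fresh file is started. The sketch below restates that policy as a standalone helper; the function name, path template, and threshold values are illustrative assumptions, not the repository's actual API.

from pathlib import Path

import pandas as pd

# Illustrative constants; the real values live in lerobot.common.datasets.utils.
DATA_FILE_SIZE_LIMIT_IN_MB = 100
FILES_PER_CHUNK = 1000


def append_or_rollover(df: pd.DataFrame, root: Path, chunk_idx: int, file_idx: int) -> tuple[int, int]:
    """Append new rows to the current parquet file, or start a new one once the size limit is hit."""
    path = root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
    current_mb = path.stat().st_size / (1024**2) if path.exists() else 0.0
    new_mb = df.memory_usage(deep=True).sum() / (1024**2)

    if current_mb + new_mb >= DATA_FILE_SIZE_LIMIT_IN_MB:
        # Size limit reached: advance the chunk/file indices and write a fresh file.
        file_idx += 1
        if file_idx >= FILES_PER_CHUNK:
            chunk_idx, file_idx = chunk_idx + 1, 0
        path = root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
    elif path.exists():
        # Below the limit: merge the new rows into the existing file before rewriting it.
        df = pd.concat([pd.read_parquet(path), df], ignore_index=True)

    path.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(path)
    return chunk_idx, file_idx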

View File

@@ -31,7 +31,7 @@ from datasets import Dataset
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import RevisionNotFoundError
-from torch.profiler import record_function
+
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
     load_nested_dataset,
     load_stats,
     load_tasks,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     validate_episode_buffer,
     validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Write the resulting dataframe from RAM to disk
         path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         path.parent.mkdir(parents=True, exist_ok=True)
-        if len(self.meta.image_keys) > 0:
-            datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
-        else:
-            df.to_parquet(path)
+        safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)

         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.

View File

@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}" f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}" f"In features not in episode_buffer: {set(features) - buffer_keys}"
) )
+
+
+def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
+    if len(image_keys) > 0:
+        # TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
+        datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+    else:
+        df.to_parquet(path)
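A minimal usage sketch of the new helper. The DataFrame here is made up; in the library, the frame holds episode rows and `image_keys` comes from the dataset metadata. The detour through `datasets.Dataset` when image columns are present is presumably needed so that image objects get serialized with the proper Arrow feature type, which plain `df.to_parquet` does not guarantee (hence the TODO).

from pathlib import Path

import pandas as pd

# Toy episode data with no image columns: the plain pandas writer is used.
df = pd.DataFrame(
    {
        "episode_index": [0, 0],
        "frame_index": [0, 1],
        "action": [[0.1, 0.2], [0.3, 0.4]],
    }
)
safe_write_dataframe_to_parquet(df, Path("/tmp/file-000.parquet"), image_keys=[])

# With image columns (e.g. image_keys=["observation.images.laptop"]), the same call
# routes the frame through datasets.Dataset before writing the parquet file.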

View File

@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)

     root_init = tmp_path / "init"
-    dataset_init = lerobot_dataset_factory(root=root_init)
+    dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)

     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())

View File

@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
-    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
+    "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
 }
 DEFAULT_FPS = 30
 DUMMY_VIDEO_INFO = {

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
+import shutil
 from functools import partial
 from pathlib import Path
 from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
+from lerobot.common.datasets.video_utils import encode_video_frames
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
     def _create_features(
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         if use_videos:
             camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
         video_path: str = DEFAULT_VIDEO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         features = features_factory(motor_features, camera_features, use_videos)
         return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
     return _create_episodes


+@pytest.fixture(scope="session")
+def create_videos(info_factory, img_array_factory):
+    def _create_video_directory(
+        root: Path,
+        info: dict | None = None,
+        total_episodes: int = 3,
+        total_frames: int = 150,
+        total_tasks: int = 1,
+    ):
+        if info is None:
+            info = info_factory(
+                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+            )
+
+        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
+        for key, ft in video_feats.items():
+            # create and save images
+            tmp_dir = root / "tmp_images"
+            tmp_dir.mkdir(parents=True, exist_ok=True)
+            for frame_index in range(info["total_frames"]):
+                img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
+                pil_img = PIL.Image.fromarray(img)
+                path = tmp_dir / f"frame-{frame_index:06d}.png"
+                pil_img.save(path)
+
+            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
+            encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
+            shutil.rmtree(tmp_dir)
+
+    return _create_video_directory


 @pytest.fixture(scope="session")
 def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
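For context, a hypothetical test exercising the new `create_videos` fixture could look like the sketch below; the test name, the `use_videos=True` override (the fixtures above now default to `False`), and the assumption that encoded videos land as `.mp4` files under `videos/` are illustrative, not part of this diff.

def test_create_videos_writes_encoded_files(tmp_path, info_factory, create_videos):
    # Build metadata with video features enabled, overriding the new default.
    info = info_factory(total_episodes=1, total_frames=10, total_tasks=1, use_videos=True)
    create_videos(root=tmp_path, info=info)

    # One encoded video per video feature is expected under the videos/ tree.
    video_files = list((tmp_path / "videos").rglob("*.mp4"))
    assert len(video_files) > 0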
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
         for key, ft in features.items():
             if ft["dtype"] == "image":
                 robot_cols[key] = [
-                    img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
+                    img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
                     for _ in range(len(index_col))
                 ]
             elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
         hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
+        # Instantiate objects
         if info is None:
             info = info_factory(
                 total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@ def lerobot_dataset_factory(
         if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if episodes_metadata is None:
-            video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episodes_metadata = episodes_factory(
                 features=info["features"],
                 fps=info["fps"],
                 total_episodes=info["total_episodes"],
                 total_frames=info["total_frames"],
-                video_keys=video_keys,
                 tasks=tasks,
                 multi_task=multi_task,
             )
         if not hf_dataset:
             hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])

+        # Write data on disk
         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,

tests/fixtures/hub.py
View File

@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
     DEFAULT_TASKS_PATH,
+    DEFAULT_VIDEO_PATH,
     INFO_PATH,
     STATS_PATH,
 )
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
     create_episodes,
     hf_dataset_factory,
     create_hf_dataset,
+    create_videos,
 ):
     """
     This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@ def mock_snapshot_download_factory(
             DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
         ]
+        video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
+        for key in video_keys:
+            all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))

         allowed_files = filter_repo_objects(
             all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
         )

-        has_info = False
-        has_tasks = False
-        has_episodes = False
-        has_stats = False
-        has_data = False
+        request_info = False
+        request_tasks = False
+        request_episodes = False
+        request_stats = False
+        request_data = False
+        request_videos = False

         for rel_path in allowed_files:
             if rel_path.startswith("meta/info.json"):
-                has_info = True
+                request_info = True
             elif rel_path.startswith("meta/stats"):
-                has_stats = True
+                request_stats = True
             elif rel_path.startswith("meta/tasks"):
-                has_tasks = True
+                request_tasks = True
             elif rel_path.startswith("meta/episodes"):
-                has_episodes = True
+                request_episodes = True
             elif rel_path.startswith("data/"):
-                has_data = True
+                request_data = True
+            elif rel_path.startswith("videos/"):
+                request_videos = True
             else:
                 raise ValueError(f"{rel_path} not supported.")

-        if has_info:
+        if request_info:
             create_info(local_dir, info)
-        if has_stats:
+        if request_stats:
             create_stats(local_dir, stats)
-        if has_tasks:
+        if request_tasks:
             create_tasks(local_dir, tasks)
-        if has_episodes:
+        if request_episodes:
             create_episodes(local_dir, episodes)
-        # TODO(rcadene): create_videos?
-        if has_data:
+        if request_data:
             create_hf_dataset(local_dir, hf_dataset)
+        if request_videos:
+            create_videos(root=local_dir, info=info)

         return str(local_dir)
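The `request_*` bookkeeping above hinges on `filter_repo_objects` from `huggingface_hub`, which narrows the simulated repo listing down to whatever patterns the caller of `snapshot_download` asked for. Below is a small standalone example of that filtering step; the file names are illustrative.

from huggingface_hub.utils import filter_repo_objects

all_files = [
    "meta/info.json",
    "meta/tasks/chunk-000/file-000.parquet",
    "data/chunk-000/file-000.parquet",
    "videos/observation.images.laptop/chunk-000/file-000.mp4",
]

# Request only metadata and videos, as a caller of snapshot_download might via allow_patterns.
allowed = list(filter_repo_objects(all_files, allow_patterns=["meta/*", "videos/*"]))
# -> ['meta/info.json', 'meta/tasks/chunk-000/file-000.parquet',
#     'videos/observation.images.laptop/chunk-000/file-000.mp4']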