Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3

This commit is contained in:
Remi Cadene
2025-05-16 17:50:14 +00:00
7 changed files with 97 additions and 33 deletions


@@ -18,6 +18,7 @@ from lerobot.common.datasets.utils import (
concat_video_files,
get_parquet_file_size_in_mb,
get_video_size_in_mb,
safe_write_dataframe_to_parquet,
update_chunk_file_indices,
write_info,
write_stats,
@@ -100,7 +101,8 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
]
)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
video_keys = [key for key in features if features[key]["dtype"] == "video"]
image_keys = [key for key in features if features[key]["dtype"] == "image"]
# Initialize output dataset metadata
dst_meta = LeRobotDatasetMetadata.create(
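The key split above is driven purely by each feature's dtype. A quick self-contained restatement with an illustrative features dict (the keys here are hypothetical; the two comprehensions are the ones added above):

features = {
    "observation.image": {"dtype": "image"},
    "observation.wrist": {"dtype": "video"},
    "action": {"dtype": "float32"},
}
video_keys = [key for key in features if features[key]["dtype"] == "video"]
image_keys = [key for key in features if features[key]["dtype"] == "image"]
assert video_keys == ["observation.wrist"]
assert image_keys == ["observation.image"]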
@@ -203,9 +205,21 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
file_idx,
)
# Update the video index tracking
video_idx["chunk_idx"] = chunk_idx
video_idx["file_idx"] = file_idx
if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
# Size limit is reached, prepare new parquet file
aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
)
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
)
aggr_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(aggr_path)
else:
# Update the existing parquet file with new rows
aggr_df = pd.read_parquet(aggr_path)
df = pd.concat([aggr_df, df], ignore_index=True)
safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
return videos_idx
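The branch above encodes a simple rollover rule: if appending the incoming rows would push the aggregated parquet file past DEFAULT_DATA_FILE_SIZE_IN_MB, advance the chunk/file indices and start a fresh file; otherwise read the existing file back, concatenate, and rewrite it. A minimal sketch of that rule under assumed constants and a simplified index helper (the real update_chunk_file_indices lives in lerobot.common.datasets.utils, and the chunk-NNN/file-NNN layout is illustrative):

import pandas as pd
from pathlib import Path

MAX_FILE_SIZE_IN_MB = 100  # stand-in for DEFAULT_DATA_FILE_SIZE_IN_MB
CHUNK_SIZE = 1000  # stand-in for DEFAULT_CHUNK_SIZE

def next_indices(chunk_idx: int, file_idx: int) -> tuple[int, int]:
    # Simplified stand-in for update_chunk_file_indices: bump the file index
    # and spill into a new chunk once the current one is full.
    file_idx += 1
    return (chunk_idx + 1, 0) if file_idx >= CHUNK_SIZE else (chunk_idx, file_idx)

def append_or_rollover(
    df: pd.DataFrame, path: Path, current_mb: float, incoming_mb: float,
    chunk_idx: int, file_idx: int,
) -> tuple[Path, int, int]:
    if current_mb + incoming_mb >= MAX_FILE_SIZE_IN_MB:
        # Size limit reached: open a fresh parquet file under the next indices.
        chunk_idx, file_idx = next_indices(chunk_idx, file_idx)
        path = path.parents[1] / f"chunk-{chunk_idx:03d}" / f"file-{file_idx:03d}.parquet"
        path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(path)
    else:
        # Below the limit: append by reading, concatenating, and rewriting in place.
        df = pd.concat([pd.read_parquet(path), df], ignore_index=True)
        df.to_parquet(path)
    return path, chunk_idx, file_idx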


@@ -31,7 +31,7 @@ from datasets import Dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from torch.profiler import record_function
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
load_nested_dataset,
load_stats,
load_tasks,
safe_write_dataframe_to_parquet,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Write the resulting dataframe from RAM to disk
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(self.meta.image_keys) > 0:
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)
safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
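The target path comes from expanding the metadata's data_path template with the current indices. A hedged sketch of the expansion (the template string below is illustrative of the v3 layout; the chunk_index/file_index keyword names are the ones used above):

DATA_PATH_TEMPLATE = "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet"
print(DATA_PATH_TEMPLATE.format(chunk_index=0, file_index=12))
# -> data/chunk-000/file-012.parquet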


@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
if len(image_keys) > 0:
# TODO(qlhoest): replace this weird syntax with a plain `df.to_parquet(path)`
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)
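Plain DataFrame.to_parquet cannot reliably serialize object columns holding image payloads, which is presumably why frames containing image keys take the detour through datasets.Dataset and its Arrow conversion; the TODO marks the detour as temporary. A hedged usage sketch (the column name is hypothetical and the nested uint8 array is a stand-in for the actual image payload; 64x96 matches the new DUMMY_CAMERA_FEATURES shape):

import numpy as np
import pandas as pd
from pathlib import Path

from lerobot.common.datasets.utils import safe_write_dataframe_to_parquet

df = pd.DataFrame({
    "frame_index": [0, 1],
    "observation.image": [np.zeros((64, 96, 3), dtype=np.uint8) for _ in range(2)],
})
# Image key present: routed through datasets.Dataset before hitting parquet.
safe_write_dataframe_to_parquet(df, Path("/tmp/file-000.parquet"), image_keys=["observation.image"])
# No image keys: equivalent to a plain df.to_parquet(path).
safe_write_dataframe_to_parquet(df[["frame_index"]], Path("/tmp/plain.parquet"), image_keys=[])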


@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())


@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import shutil
from functools import partial
from pathlib import Path
from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
get_hf_features_from_features,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import encode_video_frames
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
use_videos: bool = False,
) -> dict:
if use_videos:
camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
use_videos: bool = False,
) -> dict:
features = features_factory(motor_features, camera_features, use_videos)
return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
return _create_episodes
@pytest.fixture(scope="session")
def create_videos(info_factory, img_array_factory):
def _create_video_directory(
root: Path,
info: dict | None = None,
total_episodes: int = 3,
total_frames: int = 150,
total_tasks: int = 1,
):
if info is None:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
for key, ft in video_feats.items():
# create and save images
tmp_dir = root / "tmp_images"
tmp_dir.mkdir(parents=True, exist_ok=True)
for frame_index in range(info["total_frames"]):
img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1])  # shape is (height, width, channels), per the feature's "names"
pil_img = PIL.Image.fromarray(img)
path = tmp_dir / f"frame-{frame_index:06d}.png"
pil_img.save(path)
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
shutil.rmtree(tmp_dir)
return _create_video_directory
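A hedged sketch of how the new fixture might be exercised in a test. The test body is illustrative, not from the source; note that use_videos must now be requested explicitly, since the factories default to image features after the flips above:

def test_create_videos_writes_files(tmp_path, create_videos, info_factory):
    info = info_factory(total_episodes=1, total_frames=10, total_tasks=1, use_videos=True)
    create_videos(root=tmp_path, info=info)
    # Assumes encode_video_frames emits mp4 files under the videos/ tree.
    assert any(tmp_path.rglob("*.mp4"))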
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for key, ft in features.items():
if ft["dtype"] == "image":
robot_cols[key] = [
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
# Instantiate objects
if info is None:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@ def lerobot_dataset_factory(
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
# Write data on disk
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,

tests/fixtures/hub.py

@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_TASKS_PATH,
DEFAULT_VIDEO_PATH,
INFO_PATH,
STATS_PATH,
)
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
create_episodes,
hf_dataset_factory,
create_hf_dataset,
create_videos,
):
"""
This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@ def mock_snapshot_download_factory(
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
]
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
for key in video_keys:
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
has_info = False
has_tasks = False
has_episodes = False
has_stats = False
has_data = False
request_info = False
request_tasks = False
request_episodes = False
request_stats = False
request_data = False
request_videos = False
for rel_path in allowed_files:
if rel_path.startswith("meta/info.json"):
has_info = True
request_info = True
elif rel_path.startswith("meta/stats"):
has_stats = True
request_stats = True
elif rel_path.startswith("meta/tasks"):
has_tasks = True
request_tasks = True
elif rel_path.startswith("meta/episodes"):
has_episodes = True
request_episodes = True
elif rel_path.startswith("data/"):
has_data = True
request_data = True
elif rel_path.startswith("videos/"):
request_videos = True
else:
raise ValueError(f"{rel_path} not supported.")
if has_info:
if request_info:
create_info(local_dir, info)
if has_stats:
if request_stats:
create_stats(local_dir, stats)
if has_tasks:
if request_tasks:
create_tasks(local_dir, tasks)
if has_episodes:
if request_episodes:
create_episodes(local_dir, episodes)
# TODO(rcadene): create_videos?
if has_data:
if request_data:
create_hf_dataset(local_dir, hf_dataset)
if request_videos:
create_videos(root=local_dir, info=info)
return str(local_dir)
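For context, filter_repo_objects comes from huggingface_hub.utils and applies fnmatch-style allow/ignore patterns (a trailing "/" is expanded to match everything under that directory); the request_* flags are then a single scan over whatever survives the filter. A quick sketch of the semantics the mock relies on (file names are illustrative):

from huggingface_hub.utils import filter_repo_objects

files = [
    "meta/info.json",
    "data/chunk-000/file-000.parquet",
    "videos/observation.image/chunk-000/file-000.mp4",
]
assert list(filter_repo_objects(files, allow_patterns=["meta/"])) == ["meta/info.json"]
assert list(filter_repo_objects(files, ignore_patterns=["videos/"])) == files[:2]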