forked from tangger/lerobot
Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3
@@ -18,6 +18,7 @@ from lerobot.common.datasets.utils import (
     concat_video_files,
     get_parquet_file_size_in_mb,
     get_video_size_in_mb,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -100,7 +101,8 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
         ]
     )
     fps, robot_type, features = validate_all_metadata(all_metadata)
-    video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
+    video_keys = [key for key in features if features[key]["dtype"] == "video"]
+    image_keys = [key for key in features if features[key]["dtype"] == "image"]

     # Initialize output dataset metadata
     dst_meta = LeRobotDatasetMetadata.create(
@@ -203,9 +205,21 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
             file_idx,
         )

         # Update the video index tracking
         video_idx["chunk_idx"] = chunk_idx
         video_idx["file_idx"] = file_idx
         if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
             # Size limit is reached, prepare new parquet file
             aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
                 aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
             )
             aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
                 chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
             )
             aggr_path.parent.mkdir(parents=True, exist_ok=True)
             df.to_parquet(aggr_path)
         else:
             # Update the existing parquet file with new rows
             aggr_df = pd.read_parquet(aggr_path)
             df = pd.concat([aggr_df, df], ignore_index=True)
             safe_write_dataframe_to_parquet(df, aggr_path, image_keys)

     return videos_idx
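Context for the rollover above: `update_chunk_file_indices` advances a `(chunk_idx, file_idx)` pair once the current data file hits `DEFAULT_DATA_FILE_SIZE_IN_MB`. A minimal sketch of the arithmetic, assuming files simply roll over into a new chunk after `DEFAULT_CHUNK_SIZE` files; the real helper in `lerobot.common.datasets.utils` may differ:

```python
# Hedged sketch of the (chunk_idx, file_idx) rollover; the constant value
# and the helper body are assumptions, not taken from this diff.
DEFAULT_CHUNK_SIZE = 1000  # illustrative only


def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunk_size: int) -> tuple[int, int]:
    """Advance to the next file, opening a new chunk when the current one is full."""
    if file_idx >= chunk_size - 1:
        return chunk_idx + 1, 0
    return chunk_idx, file_idx + 1


assert update_chunk_file_indices(0, 998, DEFAULT_CHUNK_SIZE) == (0, 999)
assert update_chunk_file_indices(0, 999, DEFAULT_CHUNK_SIZE) == (1, 0)
```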
@@ -31,7 +31,7 @@ from datasets import Dataset
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import RevisionNotFoundError
 from torch.profiler import record_function

 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
     load_nested_dataset,
     load_stats,
     load_tasks,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     validate_episode_buffer,
     validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Write the resulting dataframe from RAM to disk
         path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         path.parent.mkdir(parents=True, exist_ok=True)
-        if len(self.meta.image_keys) > 0:
-            datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
-        else:
-            df.to_parquet(path)
+        safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)

         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
+    if len(image_keys) > 0:
+        # TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
+        datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+    else:
+        df.to_parquet(path)
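The new helper centralizes a workaround that previously lived inline in `LeRobotDataset` (see the hunk above at `@@ -1008`): DataFrames with image columns are routed through `datasets.Dataset` so the Hugging Face writer serializes the image payloads, while plain frames take the direct `to_parquet` path. A hedged usage sketch; the DataFrame contents below are invented for illustration:

```python
# Usage sketch; the frame below is invented and has no image columns,
# so the helper falls through to df.to_parquet(path).
from pathlib import Path

import pandas as pd

from lerobot.common.datasets.utils import safe_write_dataframe_to_parquet

df = pd.DataFrame({"episode_index": [0, 0], "frame_index": [0, 1]})
safe_write_dataframe_to_parquet(df, Path("/tmp/file-000.parquet"), image_keys=[])
```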
@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)

     root_init = tmp_path / "init"
-    dataset_init = lerobot_dataset_factory(root=root_init)
+    dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)

     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())
tests/fixtures/constants.py (4 changes)
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
-    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
+    "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
 }
 DEFAULT_FPS = 30
 DUMMY_VIDEO_INFO = {
tests/fixtures/dataset_factories.py (44 changes)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
+import shutil
 from functools import partial
 from pathlib import Path
 from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
+from lerobot.common.datasets.video_utils import encode_video_frames
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
     def _create_features(
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         if use_videos:
             camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
         video_path: str = DEFAULT_VIDEO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         features = features_factory(motor_features, camera_features, use_videos)
         return {
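Both factories flip their `use_videos` default from `True` to `False`, so fixtures now produce `image` features unless a test opts back in. A hypothetical check of the two modes; the assertions are kept generic because the exact feature naming is not shown in this diff:

```python
# Hypothetical illustration of the flipped default; not part of the diff.
def test_feature_dtypes(features_factory):
    feats_default = features_factory()               # use_videos=False
    feats_video = features_factory(use_videos=True)  # opt back in
    assert "video" not in {ft["dtype"] for ft in feats_default.values()}
    assert "video" in {ft["dtype"] for ft in feats_video.values()}
```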
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
     return _create_episodes


+@pytest.fixture(scope="session")
+def create_videos(info_factory, img_array_factory):
+    def _create_video_directory(
+        root: Path,
+        info: dict | None = None,
+        total_episodes: int = 3,
+        total_frames: int = 150,
+        total_tasks: int = 1,
+    ):
+        if info is None:
+            info = info_factory(
+                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+            )
+
+        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
+        for key, ft in video_feats.items():
+            # create and save images
+            tmp_dir = root / "tmp_images"
+            tmp_dir.mkdir(parents=True, exist_ok=True)
+            for frame_index in range(info["total_frames"]):
+                img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
+                pil_img = PIL.Image.fromarray(img)
+                path = tmp_dir / f"frame-{frame_index:06d}.png"
+                pil_img.save(path)
+
+            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
+            encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
+            shutil.rmtree(tmp_dir)
+
+    return _create_video_directory
+
+
 @pytest.fixture(scope="session")
 def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
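A hypothetical test showing how the new session-scoped fixture might be exercised; the container extension is an assumption, since `DEFAULT_VIDEO_PATH` is not expanded anywhere in this diff:

```python
# Hypothetical usage of the new fixture; not part of the diff.
def test_create_videos_writes_encoded_files(tmp_path, create_videos):
    create_videos(root=tmp_path, total_episodes=1, total_frames=10)
    # Assumes DEFAULT_VIDEO_PATH points under videos/ with an .mp4 container.
    assert any(tmp_path.rglob("*.mp4"))
```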
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
         for key, ft in features.items():
             if ft["dtype"] == "image":
                 robot_cols[key] = [
-                    img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
+                    img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
                     for _ in range(len(index_col))
                 ]
             elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
         hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
         # Instantiate objects
         if info is None:
             info = info_factory(
                 total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@
         if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if episodes_metadata is None:
             video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episodes_metadata = episodes_factory(
                 features=info["features"],
                 fps=info["fps"],
                 total_episodes=info["total_episodes"],
                 total_frames=info["total_frames"],
                 video_keys=video_keys,
                 tasks=tasks,
                 multi_task=multi_task,
             )
         if not hf_dataset:
             hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])

         # Write data on disk
         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,
tests/fixtures/hub.py (42 changes)
@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
     DEFAULT_TASKS_PATH,
+    DEFAULT_VIDEO_PATH,
     INFO_PATH,
     STATS_PATH,
 )
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
     create_episodes,
     hf_dataset_factory,
     create_hf_dataset,
+    create_videos,
 ):
     """
     This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@ def mock_snapshot_download_factory(
             DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
         ]

+        video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
+        for key in video_keys:
+            all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
+
         allowed_files = filter_repo_objects(
             all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
         )

-        has_info = False
-        has_tasks = False
-        has_episodes = False
-        has_stats = False
-        has_data = False
+        request_info = False
+        request_tasks = False
+        request_episodes = False
+        request_stats = False
+        request_data = False
+        request_videos = False
         for rel_path in allowed_files:
             if rel_path.startswith("meta/info.json"):
-                has_info = True
+                request_info = True
             elif rel_path.startswith("meta/stats"):
-                has_stats = True
+                request_stats = True
             elif rel_path.startswith("meta/tasks"):
-                has_tasks = True
+                request_tasks = True
             elif rel_path.startswith("meta/episodes"):
-                has_episodes = True
+                request_episodes = True
             elif rel_path.startswith("data/"):
-                has_data = True
+                request_data = True
+            elif rel_path.startswith("videos/"):
+                request_videos = True
             else:
                 raise ValueError(f"{rel_path} not supported.")

-        if has_info:
+        if request_info:
             create_info(local_dir, info)
-        if has_stats:
+        if request_stats:
             create_stats(local_dir, stats)
-        if has_tasks:
+        if request_tasks:
             create_tasks(local_dir, tasks)
-        if has_episodes:
+        if request_episodes:
             create_episodes(local_dir, episodes)
         # TODO(rcadene): create_videos?
-        if has_data:
+        if request_data:
             create_hf_dataset(local_dir, hf_dataset)
+        if request_videos:
+            create_videos(root=local_dir, info=info)

         return str(local_dir)
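The `has_*` to `request_*` rename matches what these flags actually track: whether a file type was requested via the allow/ignore patterns, not whether it already exists. With the new `videos/` branch, a requested video pattern now materializes encoded files through `create_videos` instead of raising. A hypothetical sketch of the flow; the fixture name and call signature are assumptions based on the factory shown above:

```python
# Hypothetical sketch, not part of the diff. The factory-made mock stands in
# for huggingface_hub.snapshot_download inside tests.
def test_requesting_videos_materializes_files(mock_snapshot_download_factory):
    mock_snapshot_download = mock_snapshot_download_factory()
    # A videos/ pattern flips request_videos to True, so the mock creates
    # encoded videos via create_videos(root=local_dir, info=info) instead of
    # downloading anything.
    local_dir = mock_snapshot_download(allow_patterns=["meta/*", "videos/*"])
    assert local_dir
```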