Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3
@@ -18,6 +18,7 @@ from lerobot.common.datasets.utils import (
     concat_video_files,
     get_parquet_file_size_in_mb,
     get_video_size_in_mb,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -100,7 +101,8 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
         ]
     )
     fps, robot_type, features = validate_all_metadata(all_metadata)
-    video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
+    video_keys = [key for key in features if features[key]["dtype"] == "video"]
+    image_keys = [key for key in features if features[key]["dtype"] == "image"]
 
     # Initialize output dataset metadata
     dst_meta = LeRobotDatasetMetadata.create(
@@ -203,9 +205,21 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
             file_idx,
         )
 
-        # Update the video index tracking
-        video_idx["chunk_idx"] = chunk_idx
-        video_idx["file_idx"] = file_idx
+        if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
+            # Size limit is reached, prepare new parquet file
+            aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
+                aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
+            )
+            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
+                chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
+            )
+            aggr_path.parent.mkdir(parents=True, exist_ok=True)
+            df.to_parquet(aggr_path)
+        else:
+            # Update the existing parquet file with new rows
+            aggr_df = pd.read_parquet(aggr_path)
+            df = pd.concat([aggr_df, df], ignore_index=True)
+            safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
 
     return videos_idx
 
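The rollover branch above leans on update_chunk_file_indices to advance the (chunk, file) pair once the current parquet file would exceed DEFAULT_DATA_FILE_SIZE_IN_MB. A minimal sketch of that bookkeeping, assuming the helper fills a chunk with DEFAULT_CHUNK_SIZE files before opening the next one; the call signature is taken from the hunk, but the body and the constant's value are assumptions, not the library source:

# Sketch only: assumed behavior of update_chunk_file_indices.
DEFAULT_CHUNK_SIZE = 1000  # assumed value, for illustration

def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
    # Advance to the next file; roll over to a fresh chunk once the current one is full.
    if file_idx >= chunks_size - 1:
        return chunk_idx + 1, 0
    return chunk_idx, file_idx + 1

assert update_chunk_file_indices(0, 5, DEFAULT_CHUNK_SIZE) == (0, 6)
assert update_chunk_file_indices(0, 999, DEFAULT_CHUNK_SIZE) == (1, 0)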
@@ -31,7 +31,7 @@ from datasets import Dataset
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import RevisionNotFoundError
-from torch.profiler import record_function
+
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
     load_nested_dataset,
     load_stats,
     load_tasks,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     validate_episode_buffer,
     validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Write the resulting dataframe from RAM to disk
         path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         path.parent.mkdir(parents=True, exist_ok=True)
-        if len(self.meta.image_keys) > 0:
-            datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
-        else:
-            df.to_parquet(path)
+        safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
 
         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
+    if len(image_keys) > 0:
+        # TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
+        datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+    else:
+        df.to_parquet(path)
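The new helper centralizes the workaround flagged in the TODO: frames carrying image columns are written through datasets.Dataset, whose Parquet writer presumably handles the image feature encoding that a plain df.to_parquet(path) does not, while image-free frames keep the pandas fast path. A usage sketch with a made-up dataframe; the on-disk path is illustrative (the real template comes from DEFAULT_DATA_PATH):

# Usage sketch only; column names and path are hypothetical.
from pathlib import Path

import pandas as pd

from lerobot.common.datasets.utils import safe_write_dataframe_to_parquet

df = pd.DataFrame({"index": [0, 1], "task_index": [0, 0]})
path = Path("data/chunk-000/file-000.parquet")
path.parent.mkdir(parents=True, exist_ok=True)

# No image columns, so this takes the plain df.to_parquet path; with a non-empty
# image_keys list it would round-trip through datasets.Dataset instead.
safe_write_dataframe_to_parquet(df, path, image_keys=[])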
@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
 
     root_init = tmp_path / "init"
-    dataset_init = lerobot_dataset_factory(root=root_init)
+    dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
 
     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())
tests/fixtures/constants.py (4 changes, vendored)
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
-    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
+    "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
 }
 DEFAULT_FPS = 30
 DUMMY_VIDEO_INFO = {
tests/fixtures/dataset_factories.py (44 changes, vendored)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
+import shutil
 from functools import partial
 from pathlib import Path
 from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
+from lerobot.common.datasets.video_utils import encode_video_frames
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
     def _create_features(
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         if use_videos:
             camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
         video_path: str = DEFAULT_VIDEO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         features = features_factory(motor_features, camera_features, use_videos)
         return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
     return _create_episodes
 
 
+@pytest.fixture(scope="session")
+def create_videos(info_factory, img_array_factory):
+    def _create_video_directory(
+        root: Path,
+        info: dict | None = None,
+        total_episodes: int = 3,
+        total_frames: int = 150,
+        total_tasks: int = 1,
+    ):
+        if info is None:
+            info = info_factory(
+                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+            )
+
+        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
+        for key, ft in video_feats.items():
+            # create and save images
+            tmp_dir = root / "tmp_images"
+            tmp_dir.mkdir(parents=True, exist_ok=True)
+            for frame_index in range(info["total_frames"]):
+                img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
+                pil_img = PIL.Image.fromarray(img)
+                path = tmp_dir / f"frame-{frame_index:06d}.png"
+                pil_img.save(path)
+
+            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
+            encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
+            shutil.rmtree(tmp_dir)
+
+    return _create_video_directory
+
+
 @pytest.fixture(scope="session")
 def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
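A hypothetical test exercising the new session-scoped fixture: pytest injects create_videos by name, and the assertion assumes the default video path template yields .mp4 files (neither the test nor that suffix is part of this commit):

# Illustrative only, not a test from this commit.
def test_create_videos_writes_files(tmp_path, create_videos):
    create_videos(root=tmp_path, total_episodes=1, total_frames=10)
    assert any(tmp_path.rglob("*.mp4"))  # assumes DEFAULT_VIDEO_PATH ends in .mp4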
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
         for key, ft in features.items():
             if ft["dtype"] == "image":
                 robot_cols[key] = [
-                    img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
+                    img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
                     for _ in range(len(index_col))
                 ]
             elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
         hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
+        # Instantiate objects
         if info is None:
             info = info_factory(
                 total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@
         if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if episodes_metadata is None:
-            video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episodes_metadata = episodes_factory(
                 features=info["features"],
                 fps=info["fps"],
                 total_episodes=info["total_episodes"],
                 total_frames=info["total_frames"],
-                video_keys=video_keys,
                 tasks=tasks,
                 multi_task=multi_task,
             )
         if not hf_dataset:
             hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
 
+        # Write data on disk
         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,
tests/fixtures/hub.py (42 changes, vendored)
@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
     DEFAULT_TASKS_PATH,
+    DEFAULT_VIDEO_PATH,
     INFO_PATH,
     STATS_PATH,
 )
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
     create_episodes,
     hf_dataset_factory,
     create_hf_dataset,
+    create_videos,
 ):
     """
     This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@
             DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
         ]
 
+        video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
+        for key in video_keys:
+            all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
+
         allowed_files = filter_repo_objects(
             all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
         )
 
-        has_info = False
-        has_tasks = False
-        has_episodes = False
-        has_stats = False
-        has_data = False
+        request_info = False
+        request_tasks = False
+        request_episodes = False
+        request_stats = False
+        request_data = False
+        request_videos = False
         for rel_path in allowed_files:
             if rel_path.startswith("meta/info.json"):
-                has_info = True
+                request_info = True
             elif rel_path.startswith("meta/stats"):
-                has_stats = True
+                request_stats = True
             elif rel_path.startswith("meta/tasks"):
-                has_tasks = True
+                request_tasks = True
             elif rel_path.startswith("meta/episodes"):
-                has_episodes = True
+                request_episodes = True
            elif rel_path.startswith("data/"):
-                has_data = True
+                request_data = True
+            elif rel_path.startswith("videos/"):
+                request_videos = True
             else:
                 raise ValueError(f"{rel_path} not supported.")
 
-        if has_info:
+        if request_info:
             create_info(local_dir, info)
-        if has_stats:
+        if request_stats:
             create_stats(local_dir, stats)
-        if has_tasks:
+        if request_tasks:
             create_tasks(local_dir, tasks)
-        if has_episodes:
+        if request_episodes:
             create_episodes(local_dir, episodes)
-        # TODO(rcadene): create_videos?
-        if has_data:
+        if request_data:
             create_hf_dataset(local_dir, hf_dataset)
+        if request_videos:
+            create_videos(root=local_dir, info=info)
 
         return str(local_dir)
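For context on how the factory's product is consumed: the returned callable stands in for huggingface_hub.snapshot_download, materializing only the requested artifact groups under the local directory. A hedged wiring sketch; the patch target and test body are illustrative, and the factory's keyword arguments are assumed to be optional:

# Illustrative only: the patched module path is an assumption about where
# snapshot_download is imported, not something this commit shows.
from unittest.mock import patch

def test_loads_from_mocked_hub(mock_snapshot_download_factory):
    mock_snapshot_download = mock_snapshot_download_factory()
    with patch(
        "lerobot.common.datasets.lerobot_dataset.snapshot_download",
        side_effect=mock_snapshot_download,
    ):
        ...  # construct a LeRobotDataset here; files are generated locally on demand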