Add download_metadata, move default paths

Simon Alibert
2024-10-18 14:48:34 +02:00
parent e7355ba595
commit 1a51505ec6
2 changed files with 68 additions and 64 deletions


@@ -120,6 +120,11 @@ from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file
+from lerobot.common.datasets.lerobot_dataset import (
+    DEFAULT_CHUNK_SIZE,
+    DEFAULT_PARQUET_PATH,
+    DEFAULT_VIDEO_PATH,
+)
 from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
 from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
@@ -127,15 +132,8 @@ from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
 V16 = "v1.6"
 V20 = "v2.0"
-EPISODE_CHUNK_SIZE = 1000
 GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
 VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
-PARQUET_CHUNK_PATH = (
-    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
-)
-VIDEO_CHUNK_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
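
Note: the three DEFAULT_* names imported in the first hunk replace the module-local constants deleted here. A minimal sketch of the moved definitions, inferred from the deleted values above (the authoritative versions live in lerobot/common/datasets/lerobot_dataset.py, the other changed file, and may differ):

# Inferred from the deletions above; actual definitions are in
# lerobot/common/datasets/lerobot_dataset.py.
DEFAULT_CHUNK_SIZE = 1000  # max episodes per chunk directory (was EPISODE_CHUNK_SIZE)
DEFAULT_PARQUET_PATH = (  # was PARQUET_CHUNK_PATH
    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
DEFAULT_VIDEO_PATH = (  # was VIDEO_CHUNK_PATH
    "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
)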
@@ -269,15 +267,15 @@ def split_parquet_by_episodes(
     table = dataset.remove_columns(keys["video"])._data.table
     episode_lengths = []
     for ep_chunk in range(total_chunks):
-        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
-        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
-        chunk_dir = "/".join(PARQUET_CHUNK_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
+        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
         (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
         for ep_idx in range(ep_chunk_start, ep_chunk_end):
             ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
             episode_lengths.insert(ep_idx, len(ep_table))
-            output_file = output_dir / PARQUET_CHUNK_PATH.format(
+            output_file = output_dir / DEFAULT_PARQUET_PATH.format(
                 episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
             )
             pq.write_table(ep_table, output_file)
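
As a quick sanity check of the chunk arithmetic above, a self-contained sketch with a hypothetical episode count:

DEFAULT_CHUNK_SIZE = 1000
total_episodes = 2500  # hypothetical

for ep_chunk in range(3):
    ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
    ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
    print(ep_chunk, ep_chunk_start, ep_chunk_end)
# 0 0 1000     -> episodes 0..999 land in data/chunk-000/
# 1 1000 2000  -> episodes 1000..1999 in data/chunk-001/
# 2 2000 2500  -> episodes 2000..2499 in data/chunk-002/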
@@ -323,16 +321,16 @@ def move_videos(
     video_dirs = sorted(work_dir.glob("videos*/"))
     for ep_chunk in range(total_chunks):
-        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
-        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
         for vid_key in video_keys:
-            chunk_dir = "/".join(VIDEO_CHUNK_PATH.split("/")[:-1]).format(
+            chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
                 episode_chunk=ep_chunk, video_key=vid_key
             )
             (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
             for ep_idx in range(ep_chunk_start, ep_chunk_end):
-                target_path = VIDEO_CHUNK_PATH.format(
+                target_path = DEFAULT_VIDEO_PATH.format(
                     episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
                 )
                 video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
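
To see what this hunk changes on disk, compare the v1 flat filename with the v2 chunked target path; the camera key below is hypothetical:

VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"

print(VIDEO_FILE.format(video_key="observation.images.top", episode_index=7))
# observation.images.top_episode_000007.mp4
print(DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key="observation.images.top", episode_index=7))
# videos/chunk-000/observation.images.top/episode_000007.mp4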
@@ -476,11 +474,12 @@ def _get_video_info(video_path: Path | str) -> dict:
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
     hub_api = HfApi()
-    videos_info_dict = {"videos_path": VIDEO_CHUNK_PATH}
+    videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
     # Assumes first episode
     video_files = [
-        VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
+        DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
+        for vid_key in video_keys
     ]
     hub_api.snapshot_download(
         repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
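
The comprehension above yields one episode-0 path per camera key, which snapshot_download then uses as allow_patterns so only those files are fetched; with a hypothetical key it evaluates to:

video_keys = ["observation.images.top"]  # hypothetical
video_files = [
    DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
    for vid_key in video_keys
]
print(video_files)
# ['videos/chunk-000/observation.images.top/episode_000000.mp4']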
@@ -587,8 +586,8 @@ def convert_dataset(
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))
     total_videos = total_episodes * len(keys["video"])
-    total_chunks = total_episodes // EPISODE_CHUNK_SIZE
-    if total_episodes % EPISODE_CHUNK_SIZE != 0:
+    total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
+    if total_episodes % DEFAULT_CHUNK_SIZE != 0:
         total_chunks += 1

     # Tasks
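
The two lines computing total_chunks implement ceiling division; a quick check at the chunk boundary:

DEFAULT_CHUNK_SIZE = 1000
for total_episodes in (999, 1000, 1001):
    total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
    if total_episodes % DEFAULT_CHUNK_SIZE != 0:
        total_chunks += 1
    print(total_episodes, total_chunks)  # 999 -> 1, 1000 -> 1, 1001 -> 2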
@@ -670,14 +669,14 @@ def convert_dataset(
     # Assemble metadata v2.0
     metadata_v2_0 = {
         "codebase_version": V20,
-        "data_path": PARQUET_CHUNK_PATH,
+        "data_path": DEFAULT_PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
         "total_frames": len(dataset),
         "total_tasks": len(tasks),
         "total_videos": total_videos,
         "total_chunks": total_chunks,
-        "chunks_size": EPISODE_CHUNK_SIZE,
+        "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
         "keys": keys["sequence"],