Move default paths, use jsonlines for tasks
This commit is contained in:
@@ -28,6 +28,13 @@ from huggingface_hub import DatasetCard, HfApi
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
# Max number of episodes per chunk
DEFAULT_CHUNK_SIZE = 1000

# str.format-style path templates, relative to the dataset root.
# Placeholders: episode_chunk, video_key, episode_index, total_episodes.
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = (
    "data/chunk-{episode_chunk:03d}/"
    "train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
@@ -145,7 +152,7 @@ def load_hf_dataset(
|
||||
|
||||
def load_metadata(local_dir: Path) -> tuple[dict | list]:
|
||||
"""Loads metadata files from a dataset."""
|
||||
info_path = local_dir / "meta/info.json"
|
||||
info_path = local_dir / "meta/info.jsonl"
|
||||
episodes_path = local_dir / "meta/episodes.jsonl"
|
||||
stats_path = local_dir / "meta/stats.json"
|
||||
tasks_path = local_dir / "meta/tasks.json"
|
||||
@@ -159,8 +166,8 @@ def load_metadata(local_dir: Path) -> tuple[dict | list]:
|
||||
with open(stats_path) as f:
|
||||
stats = json.load(f)
|
||||
|
||||
with open(tasks_path) as f:
|
||||
tasks = json.load(f)
|
||||
with jsonlines.open(tasks_path, "r") as reader:
|
||||
tasks = list(reader)
|
||||
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
stats = unflatten_dict(stats)
|
||||
@@ -169,6 +176,28 @@ def load_metadata(local_dir: Path) -> tuple[dict | list]:
|
||||
return info, episode_dicts, stats, tasks
|
||||
|
||||
|
||||
def create_dataset_info(codebase_version: str, fps: int, robot: Robot) -> dict:
    """Build the initial info dictionary for a dataset.

    All ``total_*`` counters start at 0 and ``splits`` starts empty.

    Args:
        codebase_version: Stored as-is under ``"codebase_version"``.
        fps: Stored as-is under ``"fps"``.
        robot: Object whose ``robot_type`` attribute is stored under ``"robot_type"``.

    Returns:
        A new dict with the keys listed below, in this insertion order.
    """
    # Identification and layout of the dataset.
    info = {
        "codebase_version": codebase_version,
        "data_path": DEFAULT_PARQUET_PATH,
        "robot_type": robot.robot_type,
    }

    # Every running counter starts from zero.
    counter_keys = ("total_episodes", "total_frames", "total_tasks", "total_videos", "total_chunks")
    info.update(dict.fromkeys(counter_keys, 0))

    info.update(
        {
            "chunks_size": DEFAULT_CHUNK_SIZE,
            "fps": fps,
            "splits": {},
        }
    )
    # NOTE: the "keys", "video_keys", "image_keys", "shapes", "names" and "videos"
    # entries were left commented out in the original and are not part of the result.
    return info
|
||||
|
||||
|
||||
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
if episodes is not None:
|
||||
|
||||
Reference in New Issue
Block a user