Add extra info to dataset card, various fixes from Remi's review
@@ -27,7 +27,7 @@ import PIL.Image
 import torch
 import torch.utils
 from datasets import load_dataset
-from huggingface_hub import snapshot_download, upload_folder
+from huggingface_hub import create_repo, snapshot_download, upload_folder
 
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -44,6 +44,7 @@ from lerobot.common.datasets.utils import (
     check_version_compatibility,
     create_branch,
     create_empty_dataset_info,
+    create_lerobot_dataset_card,
     get_delta_indices,
     get_episode_data_index,
     get_features_from_robot,
@@ -54,9 +55,9 @@ from lerobot.common.datasets.utils import (
     load_info,
     load_stats,
     load_tasks,
+    serialize_dict,
     write_json,
     write_parquet,
-    write_stats,
 )
 from lerobot.common.datasets.video_utils import (
     VideoFrame,
@@ -75,11 +76,11 @@ class LeRobotDatasetMetadata:
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = None,
+        root: str | Path | None = None,
         local_files_only: bool = False,
     ):
         self.repo_id = repo_id
-        self.root = root if root is not None else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
         self.local_files_only = local_files_only
 
         # Load metadata
@@ -163,7 +164,7 @@ class LeRobotDatasetMetadata:
         return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
 
     @property
-    def names(self) -> dict[str, list[str]]:
+    def names(self) -> dict[str, list | dict]:
         """Names of the various dimensions of vector modalities."""
         return {key: ft["names"] for key, ft in self.features.items()}
 
@@ -209,7 +210,7 @@ class LeRobotDatasetMetadata:
         task_index = self.task_to_task_index.get(task, None)
         return task_index if task_index is not None else self.total_tasks
 
-    def add_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
+    def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
         self.info["total_episodes"] += 1
         self.info["total_frames"] += episode_length
 
@@ -238,24 +239,37 @@ class LeRobotDatasetMetadata:
         self.episodes.append(episode_dict)
         append_jsonlines(episode_dict, self.root / EPISODES_PATH)
 
+        # TODO(aliberts): refactor stats in save_episodes
+        # image_sampling = int(self.fps / 2) # sample 2 img/s for the stats
+        # ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
+        # ep_stats = serialize_dict(ep_stats)
+        # append_jsonlines(ep_stats, self.root / STATS_PATH)
+
     def write_video_info(self) -> None:
         """
         Warning: this function writes info from first episode videos, implicitly assuming that all videos have
         been encoded the same way. Also, this means it assumes the first episode exists.
         """
         for key in self.video_keys:
-            if key not in self.info["videos"]:
+            if not self.features[key].get("info", None):
                 video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
-                self.info["videos"][key] = get_video_info(video_path)
+                self.info["features"][key]["info"] = get_video_info(video_path)
 
         write_json(self.info, self.root / INFO_PATH)
 
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}\n"
+            f"Repository ID: '{self.repo_id}',\n"
+            f"{json.dumps(self.meta.info, indent=4)}\n"
+        )
+
     @classmethod
     def create(
         cls,
         repo_id: str,
         fps: int,
-        root: Path | None = None,
+        root: str | Path | None = None,
         robot: Robot | None = None,
         robot_type: str | None = None,
         features: dict | None = None,
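With the write_video_info() change above, per-key video metadata moves from the top-level info["videos"] entry into info["features"][key]["info"]. A sketch of the resulting info.json fragment as a Python dict (the camera key, shape, and the exact fields returned by get_video_info() are illustrative assumptions, not taken from this commit):

    info = {
        "features": {
            "observation.images.cam": {   # hypothetical camera key
                "dtype": "video",
                "shape": [480, 640, 3],   # illustrative shape
                "info": {                 # filled by get_video_info(video_path)
                    "video.fps": 30.0,    # assumed field names
                    "video.codec": "av1",
                },
            },
        },
    }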
@@ -264,7 +278,7 @@ class LeRobotDatasetMetadata:
         """Creates metadata for a LeRobotDataset."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = root if root is not None else LEROBOT_HOME / repo_id
+        obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
 
         if robot is not None:
             features = get_features_from_robot(robot, use_videos)
@@ -294,7 +308,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = None,
+        root: str | Path | None = None,
         episodes: list[int] | None = None,
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
@@ -402,7 +416,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = root if root is not None else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
@@ -437,22 +451,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
-    def push_to_hub(self, push_videos: bool = True) -> None:
+    def push_to_hub(
+        self,
+        tags: list | None = None,
+        text: str | None = None,
+        license: str | None = "mit",
+        push_videos: bool = True,
+    ) -> None:
         if not self.consolidated:
             raise RuntimeError(
                 "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
                 "Please call the dataset 'consolidate()' method first."
             )
 
         ignore_patterns = ["images/"]
         if not push_videos:
             ignore_patterns.append("videos/")
+
+        create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
         upload_folder(
             repo_id=self.repo_id,
             folder_path=self.root,
             repo_type="dataset",
             ignore_patterns=ignore_patterns,
         )
+        card = create_lerobot_dataset_card(tags=tags, text=text, info=self.meta.info, license=license)
+        card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
         create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
 
     def pull_from_repo(
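A minimal usage sketch of the extended push_to_hub() signature (the repo id, tags, and text are placeholders; it assumes the dataset was consolidated beforehand, since the method raises otherwise):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset("user/my_dataset")  # placeholder repo_id
    dataset.push_to_hub(
        tags=["LeRobot", "tutorial"],        # extra dataset card tags (placeholder values)
        text="Recorded with a demo setup.",  # free-form text added to the card
        license="mit",                       # default license
        push_videos=True,                    # set False to skip uploading videos/
    )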
@@ -501,8 +525,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
         hf_dataset = load_dataset("parquet", data_files=files, split="train")
+
+        # TODO(aliberts): hf_dataset.set_format("torch")
         hf_dataset.set_transform(hf_transform_to_torch)
-        # return hf_dataset.with_format("torch") TODO
 
         return hf_dataset
 
     @property
@@ -653,30 +678,33 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def add_frame(self, frame: dict) -> None:
         """
         This function only adds the frame to the episode_buffer. Apart from images — which are written in a
-        temporary directory — nothing is written to disk. To save those frames, the 'add_episode()' method
+        temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
         then needs to be called.
         """
         frame_index = self.episode_buffer["size"]
-        for key, ft in self.features.items():
-            if key == "frame_index":
-                self.episode_buffer[key].append(frame_index)
-            elif key == "timestamp":
-                self.episode_buffer[key].append(frame_index / self.fps)
-            elif key in frame and ft["dtype"] not in ["image", "video"]:
-                self.episode_buffer[key].append(frame[key])
-            elif key in frame and ft["dtype"] in ["image", "video"]:
+        timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
+        self.episode_buffer["frame_index"].append(frame_index)
+        self.episode_buffer["timestamp"].append(timestamp)
+
+        for key in frame:
+            if key not in self.features:
+                raise ValueError(key)
+
+            if self.features[key]["dtype"] not in ["image", "video"]:
+                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
+                self.episode_buffer[key].append(item)
+            elif self.features[key]["dtype"] in ["image", "video"]:
                 img_path = self._get_image_file_path(
                     episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
                 )
                 if frame_index == 0:
                     img_path.parent.mkdir(parents=True, exist_ok=True)
                 self._save_image(frame[key], img_path)
-                if ft["dtype"] == "image":
-                    self.episode_buffer[key].append(str(img_path))
+                self.episode_buffer[key].append(str(img_path))
 
         self.episode_buffer["size"] += 1
 
-    def add_episode(self, task: str, encode_videos: bool = False) -> None:
+    def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
         """
         This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@@ -686,49 +714,56 @@ class LeRobotDataset(torch.utils.data.Dataset):
         you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
         time for video encoding.
         """
-        episode_length = self.episode_buffer.pop("size")
-        episode_index = self.episode_buffer["episode_index"]
+        if not episode_data:
+            episode_buffer = self.episode_buffer
+
+        episode_length = episode_buffer.pop("size")
+        episode_index = episode_buffer["episode_index"]
         if episode_index != self.meta.total_episodes:
             # TODO(aliberts): Add option to use existing episode_index
             raise NotImplementedError()
 
         task_index = self.meta.get_task_index(task)
 
-        if not set(self.episode_buffer.keys()) == set(self.features):
+        if not set(episode_buffer.keys()) == set(self.features):
             raise ValueError()
 
         for key, ft in self.features.items():
             if key == "index":
-                self.episode_buffer[key] = np.arange(
+                episode_buffer[key] = np.arange(
                     self.meta.total_frames, self.meta.total_frames + episode_length
                 )
             elif key == "episode_index":
-                self.episode_buffer[key] = np.full((episode_length,), episode_index)
+                episode_buffer[key] = np.full((episode_length,), episode_index)
             elif key == "task_index":
-                self.episode_buffer[key] = np.full((episode_length,), task_index)
+                episode_buffer[key] = np.full((episode_length,), task_index)
             elif ft["dtype"] in ["image", "video"]:
                 continue
-            elif ft["shape"][0] == 1:
-                self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
-            elif ft["shape"][0] > 1:
-                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
+            elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
+                episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
+            elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
+                episode_buffer[key] = np.stack(episode_buffer[key])
             else:
-                raise ValueError()
-
-        self.meta.add_episode(episode_index, episode_length, task, task_index)
+                raise ValueError(key)
 
         self._wait_image_writer()
-        self._save_episode_table(episode_index)
+        self._save_episode_table(episode_buffer, episode_index)
+
+        self.meta.save_episode(episode_index, episode_length, task, task_index)
 
         if encode_videos and len(self.meta.video_keys) > 0:
-            self.encode_videos()
+            video_paths = self.encode_episode_videos(episode_index)
+            for key in self.meta.video_keys:
+                episode_buffer[key] = video_paths[key]
+
+        if not episode_data:  # Reset the buffer
+            self.episode_buffer = self._create_episode_buffer()
 
-        # Reset the buffer
-        self.episode_buffer = self._create_episode_buffer()
         self.consolidated = False
 
-    def _save_episode_table(self, episode_index: int) -> None:
-        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.hf_features, split="train")
+    def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
+        episode_dict = {key: episode_buffer[key] for key in self.hf_features}
+        ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
         ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
         write_parquet(ep_dataset, ep_data_path)
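The renamed save_episode() pairs with add_frame() as in this rough recording-loop sketch (the feature spec and values are illustrative; LeRobotDataset.create() also accepts robot/robot_type arguments omitted here):

    import numpy as np
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    features = {  # hypothetical feature spec
        "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
        "action": {"dtype": "float32", "shape": (6,), "names": None},
    }
    dataset = LeRobotDataset.create("user/my_dataset", fps=30, features=features)

    for _ in range(100):  # one episode of 100 frames
        dataset.add_frame({
            "observation.state": np.zeros(6, dtype=np.float32),
            "action": np.zeros(6, dtype=np.float32),
        })

    dataset.save_episode(task="pick up the cube")  # writes the episode parquet; encodes videos by default
    dataset.consolidate()                          # required before push_to_hub()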
@@ -777,16 +812,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
         Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
         since video encoding with ffmpeg is already using multithreading.
         """
-        for episode_index in range(self.meta.total_episodes):
-            for key in self.meta.video_keys:
-                video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-                if video_path.is_file():
-                    # Skip if video is already encoded. Could be the case when resuming data recording.
-                    continue
-                img_dir = self._get_image_file_path(
-                    episode_index=episode_index, image_key=key, frame_index=0
-                ).parent
-                encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+        for ep_idx in range(self.meta.total_episodes):
+            self.encode_episode_videos(ep_idx)
+
+    def encode_episode_videos(self, episode_index: int) -> dict:
+        """
+        Use ffmpeg to convert frames stored as png into mp4 videos.
+        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+        since video encoding with ffmpeg is already using multithreading.
+        """
+        video_paths = {}
+        for key in self.meta.video_keys:
+            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
+            video_paths[key] = str(video_path)
+            if video_path.is_file():
+                # Skip if video is already encoded. Could be the case when resuming data recording.
+                continue
+            img_dir = self._get_image_file_path(
+                episode_index=episode_index, image_key=key, frame_index=0
+            ).parent
+            encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+
+        return video_paths
 
     def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
         self.hf_dataset = self.load_hf_dataset()
@@ -810,27 +857,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         if run_compute_stats:
             self.stop_image_writer()
+            # TODO(aliberts): refactor stats in save_episodes
             self.meta.stats = compute_stats(self)
-            write_stats(self.meta.stats, self.root / STATS_PATH)
+            serialized_stats = serialize_dict(self.meta.stats)
+            write_json(serialized_stats, self.root / STATS_PATH)
             self.consolidated = True
         else:
             logging.warning(
                 "Skipping computation of the dataset statistics, dataset is not fully consolidated."
             )
 
-        # TODO(aliberts)
-        # - [X] add video info in info.json
-        # Sanity checks:
-        # - [X] number of files
-        # - [ ] shapes
-        # - [ ] ep_lenghts
-
     @classmethod
     def create(
         cls,
         repo_id: str,
         fps: int,
-        root: Path | None = None,
+        root: str | Path | None = None,
         robot: Robot | None = None,
         robot_type: str | None = None,
         features: dict | None = None,
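The consolidate() change above replaces write_stats() with JSON serialization. The equivalent call sequence, roughly (the "meta/stats.json" value of STATS_PATH is an assumption):

    from lerobot.common.datasets.utils import serialize_dict, write_json

    serialized_stats = serialize_dict(dataset.meta.stats)  # nested tensors/arrays -> JSON-safe lists
    write_json(serialized_stats, dataset.root / "meta/stats.json")  # STATS_PATH (assumed value)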