Add video_info, fix image_writer

This commit is contained in:
Simon Alibert
2024-10-25 16:55:33 +02:00
parent 18ffa4248b
commit e210d795de
6 changed files with 241 additions and 180 deletions

View File

@@ -22,10 +22,10 @@ from pathlib import Path
from typing import Callable
import datasets
import pyarrow.parquet as pq
import torch
import torch.utils
from datasets import load_dataset
from datasets.table import embed_table_storage
from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
@@ -57,6 +57,7 @@ from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_video_frames_torchvision,
encode_video_frames,
get_video_info,
)
from lerobot.common.robot_devices.robots.utils import Robot
@@ -391,7 +392,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.info["shapes"]
@property
def features(self) -> datasets.Features:
def features(self) -> list[str]:
return list(self._features) + self.video_keys
@property
def _features(self) -> datasets.Features:
"""Features of the hf_dataset."""
if self.hf_dataset is not None:
return self.hf_dataset.features
@@ -583,6 +588,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image=frame[cam_key],
file_path=img_path,
)
if cam_key in self.image_keys:
self.episode_buffer[cam_key].append(str(img_path))
@@ -592,7 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time for video encoding.
"""
@@ -608,7 +614,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key in self.episode_buffer:
if key in self.image_keys:
continue
if key in self.keys:
elif key in self.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index":
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
@@ -619,6 +625,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
self._wait_image_writer()
self._save_episode_table(episode_index)
if encode_videos and len(self.video_keys) > 0:
@@ -629,11 +637,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.consolidated = False
def _save_episode_table(self, episode_index: int) -> None:
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
ep_table = ep_dataset._data.table
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(ep_table, ep_data_path)
# Embed image bytes into the table before saving to parquet
format = ep_dataset.format
ep_dataset = ep_dataset.with_format("arrow")
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
ep_dataset = ep_dataset.with_format(**format)
ep_dataset.to_parquet(ep_data_path)
def _save_episode_to_metadata(
self, episode_index: int, episode_length: int, task: str, task_index: int
@@ -677,7 +691,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
if isinstance(self.image_writer, ImageWriter):
logging.warning(
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
@@ -689,18 +703,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
num_threads=num_threads,
)
def stop_image_writter(self) -> None:
def stop_image_writer(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()
self.image_writer.shutdown()
self.image_writer = None
def _wait_image_writer(self) -> None:
"""Wait for asynchronous image writer to finish."""
if self.image_writer is not None:
self.image_writer.wait()
def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in range(self.num_episodes):
for episode_index in range(self.total_episodes):
for key in self.video_keys:
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here
@@ -713,6 +732,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
def _write_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
for key in self.video_keys:
if key not in self.info["videos"]:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["videos"][key] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
@@ -720,12 +751,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
if len(self.video_keys) > 0:
self.encode_videos()
self._write_video_info()
if not keep_image_files and self.image_writer is not None:
shutil.rmtree(self.image_writer.dir)
if run_compute_stats:
self.stop_image_writter()
self.stop_image_writer()
self.stats = compute_stats(self)
write_stats(self.stats, self.root / STATS_PATH)
self.consolidated = True
@@ -735,7 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
# TODO(aliberts)
# - [ ] add video info in info.json
# - [X] add video info in info.json
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lenghts
@@ -775,7 +807,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter(
obj.start_image_writer(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)
elif (
@@ -791,7 +823,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
if len(video_keys) > 0 and not use_videos:
raise ValueError
raise ValueError()
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(
@@ -918,7 +950,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
return features
@property