Add video_info, fix image_writer
This commit is contained in:
@@ -22,10 +22,10 @@ from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
@@ -57,6 +57,7 @@ from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
@@ -391,7 +392,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return self.info["shapes"]
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
def features(self) -> list[str]:
|
||||
return list(self._features) + self.video_keys
|
||||
|
||||
@property
|
||||
def _features(self) -> datasets.Features:
|
||||
"""Features of the hf_dataset."""
|
||||
if self.hf_dataset is not None:
|
||||
return self.hf_dataset.features
|
||||
@@ -583,6 +588,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image=frame[cam_key],
|
||||
file_path=img_path,
|
||||
)
|
||||
|
||||
if cam_key in self.image_keys:
|
||||
self.episode_buffer[cam_key].append(str(img_path))
|
||||
|
||||
@@ -592,7 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
|
||||
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
|
||||
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
"""
|
||||
@@ -608,7 +614,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for key in self.episode_buffer:
|
||||
if key in self.image_keys:
|
||||
continue
|
||||
if key in self.keys:
|
||||
elif key in self.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
@@ -619,6 +625,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
|
||||
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
|
||||
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_index)
|
||||
|
||||
if encode_videos and len(self.video_keys) > 0:
|
||||
@@ -629,11 +637,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
|
||||
ep_table = ep_dataset._data.table
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(ep_table, ep_data_path)
|
||||
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = ep_dataset.format
|
||||
ep_dataset = ep_dataset.with_format("arrow")
|
||||
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
|
||||
ep_dataset = ep_dataset.with_format(**format)
|
||||
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
@@ -677,7 +691,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
if isinstance(self.image_writer, ImageWriter):
|
||||
logging.warning(
|
||||
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
|
||||
@@ -689,18 +703,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def stop_image_writter(self) -> None:
|
||||
def stop_image_writer(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
self.image_writer.shutdown()
|
||||
self.image_writer = None
|
||||
|
||||
def _wait_image_writer(self) -> None:
    """Block until the asynchronous image writer, if one is attached, has finished."""
    writer = self.image_writer
    if writer is None:
        # No writer attached to this dataset: nothing to wait for.
        return
    writer.wait()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in range(self.num_episodes):
|
||||
for episode_index in range(self.total_episodes):
|
||||
for key in self.video_keys:
|
||||
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
|
||||
# to call self.image_writer here
|
||||
@@ -713,6 +732,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
def _write_video_info(self) -> None:
    """
    Fill in the missing per-video entries of self.info["videos"] and write self.info to disk.

    Warning: the info is read from the videos of the first episode (ep_index=0) only, which
    implicitly assumes that all videos have been encoded the same way and that the first
    episode exists.
    """
    for vid_key in self.video_keys:
        videos_info = self.info["videos"]
        if vid_key in videos_info:
            # Info for this video key was already recorded; keep it untouched.
            continue

        first_episode_video = self.root / self.get_video_file_path(ep_index=0, vid_key=vid_key)
        videos_info[vid_key] = get_video_info(first_episode_video)

    # The info file is rewritten on every call, even if no entry was added.
    write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
@@ -720,12 +751,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
self._write_video_info()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.dir)
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writter()
|
||||
self.stop_image_writer()
|
||||
self.stats = compute_stats(self)
|
||||
write_stats(self.stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
@@ -735,7 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
# TODO(aliberts)
|
||||
# - [ ] add video info in info.json
|
||||
# - [X] add video info in info.json
|
||||
# Sanity checks:
|
||||
# - [ ] shapes
|
||||
# - [ ] ep_lengths
|
||||
@@ -775,7 +807,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
||||
obj.start_image_writter(
|
||||
obj.start_image_writer(
|
||||
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
||||
)
|
||||
elif (
|
||||
@@ -791,7 +823,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
if len(video_keys) > 0 and not use_videos:
|
||||
raise ValueError
|
||||
raise ValueError()
|
||||
|
||||
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(
|
||||
@@ -918,7 +950,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
|
||||
features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
|
||||
return features
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user