Add local_files_only, encode_videos, fix bugs to pass tests (WIP)

This commit is contained in:
Simon Alibert
2024-10-22 19:57:52 +02:00
parent e991a31061
commit a805458c7e
4 changed files with 183 additions and 80 deletions

View File

@@ -17,6 +17,7 @@ import json
import logging
import os
import shutil
from functools import cached_property
from pathlib import Path
from typing import Callable
@@ -30,20 +31,32 @@ from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
EPISODES_PATH,
INFO_PATH,
TASKS_PATH,
append_jsonl,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
create_branch,
create_empty_dataset_info,
flatten_dict,
get_delta_indices,
get_episode_data_index,
get_hub_safe_version,
hf_transform_to_torch,
load_metadata,
load_episode_dicts,
load_info,
load_stats,
load_tasks,
unflatten_dict,
write_json,
)
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_video_frames_torchvision,
encode_video_frames,
)
from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
@@ -61,6 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
image_writer: ImageWriter | None = None,
):
@@ -162,21 +176,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.download_videos = download_videos
self.video_backend = video_backend if video_backend is not None else "pyav"
self.image_writer = image_writer
self.delta_indices = None
self.consolidated = True
self.episode_buffer = {}
self.local_files_only = local_files_only
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
self.pull_from_repo(allow_patterns="meta/")
self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
self.episode_dicts = load_episode_dicts(self.root)
# Check version
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
# Load actual data
self.download_episodes()
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
@@ -199,6 +218,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
@property
def _version(self) -> str:
"""Codebase version used to create this dataset."""
return self.info["codebase_version"]
def push_to_repo(self, push_videos: bool = True) -> None:
if not self.consolidated:
raise RuntimeError(
@@ -225,13 +253,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._version,
revision=self._hub_version,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)
def download_episodes(self) -> None:
def download_episodes(self, download_videos: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -240,10 +269,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
ignore_patterns = None if self.download_videos else "videos/"
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
if len(self.video_keys) > 0 and self.download_videos:
if len(self.video_keys) > 0 and download_videos:
video_files = [
self.get_video_file_path(ep_idx, vid_key)
for vid_key in self.video_keys
@@ -495,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.camera_keys if self.download_videos else self.image_keys
image_keys = self.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
@@ -521,6 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"timestamp": [],
"next.done": [],
**{key: [] for key in self.keys},
**{key: [] for key in self.image_keys},
}
def add_frame(self, frame: dict) -> None:
@@ -553,6 +583,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
image=frame[cam_key],
file_path=img_path,
)
if cam_key in self.image_keys:
self.episode_buffer[cam_key].append(str(img_path))
def add_episode(self, task: str, encode_videos: bool = False) -> None:
"""
@@ -574,6 +606,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["next.done"][-1] = True
for key in self.episode_buffer:
if key in self.image_keys:
continue
if key in self.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index":
@@ -583,11 +617,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
self._save_episode_table(episode_index)
if encode_videos:
pass # TODO
if encode_videos and len(self.video_keys) > 0:
self.encode_videos()
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
@@ -614,7 +649,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"task_index": task_index,
"task": task,
}
append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
append_jsonl(task_dict, self.root / TASKS_PATH)
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
@@ -622,22 +657,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / "meta/info.json")
write_json(self.info, self.root / INFO_PATH)
episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"length": episode_length,
}
append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
self.episode_dicts.append(episode_dict)
append_jsonl(episode_dict, self.root / EPISODES_PATH)
def delete_episode(self) -> None:
episode_index = self.episode_buffer["episode_index"]
if self.image_writer is not None:
for cam_key in self.camera_keys:
cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
if cam_dir.is_dir():
shutil.rmtree(cam_dir)
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
if img_dir.is_dir():
shutil.rmtree(img_dir)
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
@@ -653,27 +689,54 @@ class LeRobotDataset(torch.utils.data.Dataset):
updated_file_name = self.get_data_file_path(ep_idx)
current_file_name.rename(updated_file_name)
def _remove_image_writer(self) -> None:
if self.image_writer is not None:
self.image_writer = None
def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in range(self.num_episodes):
for key in self.video_keys:
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
video_path = self.get_video_file_path(episode_index, key, return_str=False)
if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
def consolidate(self, run_compute_stats: bool = True) -> None:
self._update_data_file_names()
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.video_keys) > 0:
self.encode_videos()
if run_compute_stats:
logging.info("Computing dataset statistics")
self.hf_dataset = self.load_hf_dataset()
self._remove_image_writer()
self.stats = compute_stats(self)
serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
serialized_stats = flatten_dict(self.stats)
serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
write_json(serialized_stats, self.root / "meta/stats.json")
self.consolidated = True
else:
logging.warning("Skipping computation of the dataset statistics.")
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
pass # TODO
# TODO(aliberts)
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lenghts
# - [ ] number of files
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
# - [ ] no remaining self.image_writer.dir
self.consolidated = True
@classmethod
def create(
@@ -691,7 +754,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj._version = CODEBASE_VERSION
obj.tolerance_s = tolerance_s
obj.image_writer = image_writer
@@ -702,21 +764,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
write_json(obj.info, obj.root / "meta/info.json")
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
write_json(obj.info, obj.root / INFO_PATH)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
# It is used to know when certain operations are need (for instance, computing dataset statistics).
# In order to be able to push the dataset to the hub, it needs to be consolidation first.
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.local_files_only = True
obj.download_videos = False
obj.episodes = None
obj.hf_dataset = None
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj