Add local_files_only, encode_videos, fix bugs to pass tests (WIP)
@@ -17,6 +17,7 @@ import json
 import logging
 import os
 import shutil
+from functools import cached_property
 from pathlib import Path
 from typing import Callable

@@ -30,20 +31,32 @@ from huggingface_hub import snapshot_download, upload_folder
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
+    EPISODES_PATH,
+    INFO_PATH,
+    TASKS_PATH,
     append_jsonl,
     check_delta_timestamps,
     check_timestamps_sync,
     check_version_compatibility,
     create_branch,
     create_empty_dataset_info,
+    flatten_dict,
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
     hf_transform_to_torch,
-    load_metadata,
+    load_episode_dicts,
+    load_info,
+    load_stats,
+    load_tasks,
+    unflatten_dict,
     write_json,
 )
-from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
+from lerobot.common.datasets.video_utils import (
+    VideoFrame,
+    decode_video_frames_torchvision,
+    encode_video_frames,
+)
 from lerobot.common.robot_devices.robots.utils import Robot

 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
@@ -61,6 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
         download_videos: bool = True,
+        local_files_only: bool = False,
         video_backend: str | None = None,
         image_writer: ImageWriter | None = None,
     ):
@@ -162,21 +176,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
-        self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
         self.delta_indices = None
         self.consolidated = True
         self.episode_buffer = {}
+        self.local_files_only = local_files_only

         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
-        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
         self.pull_from_repo(allow_patterns="meta/")
-        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
+        self.info = load_info(self.root)
+        self.stats = load_stats(self.root)
+        self.tasks = load_tasks(self.root)
+        self.episode_dicts = load_episode_dicts(self.root)
+
+        # Check version
+        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)

         # Load actual data
-        self.download_episodes()
+        self.download_episodes(download_videos)
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
@@ -199,6 +218,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
     # - [ ] Update episode_index (arg update=True)
     # - [ ] Update info.json (arg update=True)

+    @cached_property
+    def _hub_version(self) -> str | None:
+        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
+
+    @property
+    def _version(self) -> str:
+        """Codebase version used to create this dataset."""
+        return self.info["codebase_version"]
+
     def push_to_repo(self, push_videos: bool = True) -> None:
         if not self.consolidated:
             raise RuntimeError(
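Note: taken together, the new constructor argument and these two properties make offline use explicit. With local_files_only=True, _hub_version short-circuits to None so no Hub request is made, and _version is now read from the info.json already on disk rather than resolved remotely. A minimal usage sketch (the repo_id is a placeholder, and this assumes the dataset files already exist under the local root):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Placeholder repo_id; assumes meta/ and data files were previously
    # downloaded or created locally, since no Hub call will be made.
    dataset = LeRobotDataset("lerobot/pusht", local_files_only=True)
    print(dataset._version)  # codebase version read from meta/info.json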
@@ -225,13 +253,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self._version,
+            revision=self._hub_version,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
+            local_files_only=self.local_files_only,
         )

-    def download_episodes(self) -> None:
+    def download_episodes(self, download_videos: bool = True) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
         dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -240,10 +269,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
         files = None
-        ignore_patterns = None if self.download_videos else "videos/"
+        ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
             files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
-            if len(self.video_keys) > 0 and self.download_videos:
+            if len(self.video_keys) > 0 and download_videos:
                 video_files = [
                     self.get_video_file_path(ep_idx, vid_key)
                     for vid_key in self.video_keys
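Note: since download filtering is now keyed off the download_videos argument rather than instance state, a partial, video-free download reads like this (the repo_id and episode indices are illustrative only):

    # Fetch only three episodes' data files; "videos/" is passed to
    # snapshot_download as ignore_patterns, so no mp4 files are pulled.
    dataset = LeRobotDataset(
        "lerobot/pusht",  # placeholder repo_id
        episodes=[0, 3, 7],
        download_videos=False,
    )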
@@ -495,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             item = {**video_frames, **item}

         if self.image_transforms is not None:
-            image_keys = self.camera_keys if self.download_videos else self.image_keys
+            image_keys = self.camera_keys
             for cam in image_keys:
                 item[cam] = self.image_transforms(item[cam])

@@ -521,6 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "timestamp": [],
             "next.done": [],
             **{key: [] for key in self.keys},
+            **{key: [] for key in self.image_keys},
         }

     def add_frame(self, frame: dict) -> None:
@@ -553,6 +583,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 image=frame[cam_key],
                 file_path=img_path,
             )
+            if cam_key in self.image_keys:
+                self.episode_buffer[cam_key].append(str(img_path))

     def add_episode(self, task: str, encode_videos: bool = False) -> None:
         """
@@ -574,6 +606,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer["next.done"][-1] = True

         for key in self.episode_buffer:
+            if key in self.image_keys:
+                continue
             if key in self.keys:
                 self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
             elif key == "episode_index":
@@ -583,11 +617,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
             else:
                 self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])

         self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
         self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
         self._save_episode_table(episode_index)

-        if encode_videos:
-            pass # TODO
+        if encode_videos and len(self.video_keys) > 0:
+            self.encode_videos()

         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
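Note: with image keys buffered per frame and encoding wired into add_episode, the recording flow these hunks build toward looks roughly like this. A hedged sketch only: the create() arguments, the robot object, and the frames iterable are placeholders, not part of this commit.

    dataset = LeRobotDataset.create("my_user/my_dataset", fps=30, robot=robot)
    for frame in frames:          # dicts of observations/actions from the robot loop
        dataset.add_frame(frame)  # images are written to disk via self.image_writer
    dataset.add_episode(task="pick up the cube", encode_videos=True)
    dataset.consolidate(run_compute_stats=True)
    dataset.push_to_repo()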
@@ -614,7 +649,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "task_index": task_index,
             "task": task,
         }
-        append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
+        append_jsonl(task_dict, self.root / TASKS_PATH)

         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
@@ -622,22 +657,23 @@ class LeRobotDataset(torch.utils.data.Dataset):

         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
-        write_json(self.info, self.root / "meta/info.json")
+        write_json(self.info, self.root / INFO_PATH)

         episode_dict = {
             "episode_index": episode_index,
             "tasks": [task],
             "length": episode_length,
         }
-        append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
+        self.episode_dicts.append(episode_dict)
+        append_jsonl(episode_dict, self.root / EPISODES_PATH)

     def delete_episode(self) -> None:
         episode_index = self.episode_buffer["episode_index"]
         if self.image_writer is not None:
             for cam_key in self.camera_keys:
-                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
-                if cam_dir.is_dir():
-                    shutil.rmtree(cam_dir)
+                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
+                if img_dir.is_dir():
+                    shutil.rmtree(img_dir)

         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
@@ -653,27 +689,54 @@ class LeRobotDataset(torch.utils.data.Dataset):
             updated_file_name = self.get_data_file_path(ep_idx)
             current_file_name.rename(updated_file_name)

+    def _remove_image_writer(self) -> None:
+        if self.image_writer is not None:
+            self.image_writer = None
+
+    def encode_videos(self) -> None:
+        # Use ffmpeg to convert frames stored as png into mp4 videos
+        for episode_index in range(self.num_episodes):
+            for key in self.video_keys:
+                # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
+                # to call self.image_writer here
+                tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
+                video_path = self.get_video_file_path(episode_index, key, return_str=False)
+                if video_path.is_file():
+                    # Skip if video is already encoded. Could be the case when resuming data recording.
+                    continue
+                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+                # since video encoding with ffmpeg is already using multithreading.
+                encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
+                shutil.rmtree(tmp_imgs_dir)

     def consolidate(self, run_compute_stats: bool = True) -> None:
         self._update_data_file_names()
+        self.hf_dataset = self.load_hf_dataset()
+        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
+        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+
+        if len(self.video_keys) > 0:
+            self.encode_videos()

         if run_compute_stats:
             logging.info("Computing dataset statistics")
-            self.hf_dataset = self.load_hf_dataset()
+            self._remove_image_writer()
             self.stats = compute_stats(self)
-            serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
+            serialized_stats = flatten_dict(self.stats)
+            serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
+            serialized_stats = unflatten_dict(serialized_stats)
             write_json(serialized_stats, self.root / "meta/stats.json")
-            self.consolidated = True
         else:
             logging.warning("Skipping computation of the dataset statistics.")

-        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
-        pass # TODO
+        # TODO(aliberts)
+        # Sanity checks:
+        # - [ ] shapes
+        # - [ ] ep_lengths
+        # - [ ] number of files
+        # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
+        # - [ ] no remaining self.image_writer.dir
+        self.consolidated = True

     @classmethod
     def create(
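Note: the stats serialization change fixes a real bug. compute_stats returns a nested dict (per feature key, then per statistic) of tensors, so calling .tolist() only on the top-level values fails; flattening first makes every value a tensor, after which .tolist() and unflattening yield a JSON-safe nested dict. A standalone sketch of the round trip, assuming flatten_dict/unflatten_dict from lerobot.common.datasets.utils join and split nested keys (the exact separator is an assumption here):

    import torch
    from lerobot.common.datasets.utils import flatten_dict, unflatten_dict

    stats = {"observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)}}
    flat = flatten_dict(stats)                       # e.g. {"observation.state/mean": tensor([0., 0.]), ...}
    flat = {k: v.tolist() for k, v in flat.items()}  # tensors -> plain lists
    serialized = unflatten_dict(flat)                # nested dict of lists, ready for write_json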
@@ -691,7 +754,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
-        obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer

@@ -702,21 +764,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )

+        obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
-        obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
-        write_json(obj.info, obj.root / "meta/info.json")
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
+        write_json(obj.info, obj.root / INFO_PATH)

         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()

-        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
-        # It is used to know when certain operations are need (for instance, computing dataset statistics).
-        # In order to be able to push the dataset to the hub, it needs to be consolidation first.
+        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
+        # is used to know when certain operations are needed (for instance, computing dataset statistics). In
+        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
+        # self.consolidate().
         obj.consolidated = True

+        obj.local_files_only = True
-        obj.download_videos = False

         obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None
         obj.delta_timestamps = None
+        obj.delta_indices = None
+        obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj