Improve dataset v2 (#498)
This commit is contained in:
@@ -280,6 +280,8 @@ class LeRobotDatasetMetadata:
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
robot_type = robot.robot_type
|
||||
@@ -293,6 +295,7 @@ class LeRobotDatasetMetadata:
|
||||
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
||||
)
|
||||
else:
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
||||
@@ -424,11 +427,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.delta_indices = None
|
||||
self.local_files_only = local_files_only
|
||||
self.consolidated = True
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = {}
|
||||
self.episode_buffer = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -451,12 +453,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Available stats implies all videos have been encoded and dataset is iterable
|
||||
self.consolidated = self.meta.stats is not None
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
tags: list | None = None,
|
||||
text: str | None = None,
|
||||
license: str | None = "mit",
|
||||
license: str | None = "apache-2.0",
|
||||
push_videos: bool = True,
|
||||
private: bool = False,
|
||||
) -> None:
|
||||
if not self.consolidated:
|
||||
raise RuntimeError(
|
||||
@@ -468,7 +474,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
|
||||
create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
|
||||
create_repo(
|
||||
repo_id=self.repo_id,
|
||||
private=private,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
folder_path=self.root,
|
||||
@@ -658,7 +670,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
return {
|
||||
"size": 0,
|
||||
**{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
|
||||
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
|
||||
}
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
@@ -681,8 +693,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
then needs to be called.
|
||||
"""
|
||||
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
|
||||
# check the dtype and shape matches, etc.
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
@@ -723,6 +741,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError()
|
||||
|
||||
if episode_length == 0:
|
||||
raise ValueError(
|
||||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
|
||||
task_index = self.meta.get_task_index(task)
|
||||
|
||||
if not set(episode_buffer.keys()) == set(self.features):
|
||||
@@ -781,7 +804,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
logging.warning(
|
||||
"You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
|
||||
|
||||
Reference in New Issue
Block a user