Improve dataset v2 (#498)

This commit is contained in:
Remi
2024-11-19 12:31:47 +01:00
committed by GitHub
parent acae4b49d2
commit 1f13bda25b
9 changed files with 393 additions and 70 deletions

View File

@@ -280,6 +280,8 @@ class LeRobotDatasetMetadata:
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
robot_type = robot.robot_type
@@ -293,6 +295,7 @@ class LeRobotDatasetMetadata:
"Dataset features must either come from a Robot or explicitly passed upon creation."
)
else:
# TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.stats, obj.episodes = {}, {}, []
@@ -424,11 +427,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None
self.local_files_only = local_files_only
self.consolidated = True
# Unused attributes
self.image_writer = None
self.episode_buffer = {}
self.episode_buffer = None
self.root.mkdir(exist_ok=True, parents=True)
@@ -451,12 +453,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Available stats implies all videos have been encoded and dataset is iterable
self.consolidated = self.meta.stats is not None
def push_to_hub(
self,
tags: list | None = None,
text: str | None = None,
license: str | None = "mit",
license: str | None = "apache-2.0",
push_videos: bool = True,
private: bool = False,
) -> None:
if not self.consolidated:
raise RuntimeError(
@@ -468,7 +474,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not push_videos:
ignore_patterns.append("videos/")
create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
create_repo(
repo_id=self.repo_id,
private=private,
repo_type="dataset",
exist_ok=True,
)
upload_folder(
repo_id=self.repo_id,
folder_path=self.root,
@@ -658,7 +670,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
return {
"size": 0,
**{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
}
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
@@ -681,8 +693,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
then needs to be called.
"""
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
# check the dtype and shape matches, etc.
if self.episode_buffer is None:
self.episode_buffer = self._create_episode_buffer()
frame_index = self.episode_buffer["size"]
timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
@@ -723,6 +741,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError()
if episode_length == 0:
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
task_index = self.meta.get_task_index(task)
if not set(episode_buffer.keys()) == set(self.features):
@@ -781,7 +804,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
if isinstance(self.image_writer, AsyncImageWriter):
logging.warning(
"You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import textwrap
import warnings
from itertools import accumulate
from pathlib import Path
@@ -139,6 +140,8 @@ def load_info(local_dir: Path) -> dict:
def load_stats(local_dir: Path) -> dict:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
@@ -186,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
return int(split[0]), int(split[1])
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
or open an issue on GitHub.
""")
super().__init__(message)
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
warnings.warn(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
@@ -207,18 +230,16 @@ def check_version_compatibility(
)
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
num_version = float(version.strip("v"))
if num_version < 2 and enforce_v2:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
def get_hub_safe_version(repo_id: str, version: str) -> str:
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
warnings.warn(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
@@ -461,6 +482,7 @@ def create_lerobot_dataset_card(
}
]
card.data.task_categories = ["robotics"]
card.data.license = license
card.data.tags = ["LeRobot"]
if license:
card.data.license = license

View File

@@ -441,7 +441,7 @@ def convert_dataset(
arxiv: str | None = None,
test_branch: str | None = None,
):
v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
v1 = get_hub_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -17,6 +17,7 @@ from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
@@ -330,3 +331,21 @@ def sanity_check_dataset_name(repo_id, policy):
raise ValueError(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
)
def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos):
fields = [
("robot_type", dataset.meta.info["robot_type"], robot.robot_type),
("fps", dataset.meta.info["fps"], fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
]
mismatches = []
for field, dataset_value, present_value in fields:
if dataset_value != present_value:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)

View File

@@ -115,6 +115,7 @@ from lerobot.common.robot_devices.control_utils import (
record_episode,
reset_environment,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
stop_recording,
warmup_record,
)
@@ -207,6 +208,9 @@ def record(
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
local_files_only: bool = False,
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
listener = None
@@ -232,17 +236,29 @@ def record(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
)
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
robot=robot,
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera,
)
if resume:
dataset = LeRobotDataset(
repo_id,
root=root,
local_files_only=local_files_only,
)
dataset.start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
robot=robot,
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
if not robot.is_connected:
robot.connect()
@@ -270,8 +286,7 @@ def record(
# if multi_task:
# task = input("Enter your task description: ")
episode_index = dataset.episode_buffer["episode_index"]
log_say(f"Recording episode {episode_index}", play_sounds)
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
record_episode(
dataset=dataset,
robot=robot,
@@ -289,7 +304,7 @@ def record(
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(episode_index < num_episodes - 1) or events["rerecord_episode"]
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)

View File

@@ -117,10 +117,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
def push_dataset_card_to_hub(
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
repo_id: str,
revision: str | None,
tags: list | None = None,
text: str | None = None,
license: str = "apache-2.0",
):
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
card = create_lerobot_dataset_card(tags=tags, text=text)
card = create_lerobot_dataset_card(tags=tags, text=text, license=license)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)