Hardware API redesign (#777)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
@@ -38,6 +38,7 @@ from lerobot.common.datasets.utils import (
|
||||
DEFAULT_IMAGE_PATH,
|
||||
INFO_PATH,
|
||||
TASKS_PATH,
|
||||
_validate_feature_names,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
@@ -48,7 +49,6 @@ from lerobot.common.datasets.utils import (
|
||||
embed_images,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_features_from_features,
|
||||
get_safe_version,
|
||||
hf_transform_to_torch,
|
||||
@@ -72,7 +72,6 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_safe_default_codec,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
|
||||
@@ -304,10 +303,9 @@ class LeRobotDatasetMetadata:
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
root: str | Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
features: dict,
|
||||
robot_type: str | None = None,
|
||||
features: dict | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
@@ -317,33 +315,13 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
||||
)
|
||||
elif features is None:
|
||||
raise ValueError(
|
||||
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
||||
)
|
||||
else:
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
# check if none of the features contains a "/" in their names,
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
@@ -785,7 +763,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
self.image_writer.save_image(image=image, fpath=fpath)
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None:
|
||||
"""
|
||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
@@ -803,17 +781,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
if timestamp is None:
|
||||
timestamp = frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
self.episode_buffer["task"].append(task)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key in frame:
|
||||
if key == "task":
|
||||
# Note: we associate the task in natural language to its task index during `save_episode`
|
||||
self.episode_buffer["task"].append(frame["task"])
|
||||
continue
|
||||
|
||||
if key not in self.features:
|
||||
raise ValueError(
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
@@ -989,10 +964,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
root: str | Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
robot_type: str | None = None,
|
||||
features: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer_processes: int = 0,
|
||||
@@ -1004,10 +978,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.common.datasets.backward_compatibility import (
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robots import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
|
||||
@@ -387,6 +387,59 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
|
||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
||||
if invalid_features:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
||||
|
||||
|
||||
def hw_to_dataset_features(
|
||||
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
||||
) -> dict[str, dict]:
|
||||
features = {}
|
||||
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == "action":
|
||||
features[prefix] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
if joint_fts and prefix == "observation":
|
||||
features[f"{prefix}.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
|
||||
|
||||
def build_dataset_frame(
|
||||
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
||||
) -> dict[str, np.ndarray]:
|
||||
frame = {}
|
||||
for key, ft in ds_features.items():
|
||||
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
||||
continue
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
@@ -415,7 +468,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
type = FeatureType.STATE
|
||||
elif key == "action":
|
||||
elif key.startswith("action"):
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
@@ -431,9 +484,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
robot_type: str,
|
||||
features: dict,
|
||||
use_videos: bool,
|
||||
robot_type: str | None = None,
|
||||
) -> dict:
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
@@ -699,16 +752,12 @@ class IterableNamespace(SimpleNamespace):
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
optional_features = {"timestamp"}
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||
actual_features = set(frame)
|
||||
|
||||
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||
error_message = validate_features_presence(actual_features, expected_features)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += validate_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
common_features = actual_features & expected_features
|
||||
for name in common_features - {"task"}:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
@@ -716,12 +765,10 @@ def validate_frame(frame: dict, features: dict):
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def validate_features_presence(
|
||||
actual_features: set[str], expected_features: set[str], optional_features: set[str]
|
||||
):
|
||||
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - (expected_features | optional_features)
|
||||
extra_features = actual_features - expected_features
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
|
||||
@@ -27,7 +27,7 @@ from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
|
||||
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
|
||||
from lerobot.common.robots.aloha.configuration_aloha import AlohaRobotConfig
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
|
||||
@@ -141,8 +141,7 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_config
|
||||
from lerobot.common.robots import RobotConfig
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
@@ -598,6 +597,30 @@ def convert_dataset(
|
||||
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
|
||||
|
||||
|
||||
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
if robot_type == "aloha":
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
elif robot_type == "koch_follower":
|
||||
from lerobot.common.robots.koch_follower import KochFollowerConfig
|
||||
|
||||
return KochFollowerConfig(**kwargs)
|
||||
elif robot_type == "so100_follower":
|
||||
from lerobot.common.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
return SO100FollowerConfig(**kwargs)
|
||||
elif robot_type == "stretch":
|
||||
from lerobot.common.robots.stretch3 import Stretch3RobotConfig
|
||||
|
||||
return Stretch3RobotConfig(**kwargs)
|
||||
elif robot_type == "lekiwi":
|
||||
from lerobot.common.robots.lekiwi import LeKiwiConfig
|
||||
|
||||
return LeKiwiConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
task_args = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
Reference in New Issue
Block a user