Add record

This commit is contained in:
Simon Alibert
2025-05-08 13:15:37 +02:00
parent 237b14a6ec
commit 8b98399206
4 changed files with 387 additions and 25 deletions

View File

@@ -40,7 +40,7 @@ from lerobot.common.datasets.backward_compatibility import (
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.common.robots.utils import Robot
from lerobot.common.robots import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@@ -387,6 +387,52 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
return datasets.Features(hf_features)
def _validate_feature_names(features: dict[str, dict]) -> None:
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
if invalid_features:
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
def hw_to_dataset_features(
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
features = {}
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts:
features[f"{prefix}.joints"] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
for key, shape in cam_fts.items():
features[f"{prefix}.cameras.{key}"] = {
"dtype": "video" if use_video else "image",
"shape": shape,
"names": ["height", "width", "channels"],
}
_validate_feature_names(features)
return features
def build_dataset_frame(
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
frame = {}
for key, ft in ds_features.items():
if key in DEFAULT_FEATURES or not key.startswith(prefix):
continue
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.cameras.")]
return frame
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
@@ -699,16 +745,12 @@ class IterableNamespace(SimpleNamespace):
def validate_frame(frame: dict, features: dict):
optional_features = {"timestamp"}
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
error_message = validate_features_presence(actual_features, expected_features, optional_features)
error_message = validate_features_presence(actual_features, expected_features)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
common_features = actual_features & expected_features
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
@@ -716,12 +758,10 @@ def validate_frame(frame: dict, features: dict):
raise ValueError(error_message)
def validate_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - (expected_features | optional_features)
extra_features = actual_features - expected_features
if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"