forked from tangger/lerobot
Validate features during add_frame + Add 2D-to-5D + Add string (#720)
This commit is contained in:
@@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
)
|
||||
|
||||
if image_array.dtype != np.uint8:
|
||||
# Assume the image is in [0, 1] range for floating-point data
|
||||
image_array = np.clip(image_array, 0, 1)
|
||||
if range_check:
|
||||
max_ = image_array.max().item()
|
||||
min_ = image_array.min().item()
|
||||
if max_ > 1.0 or min_ < 0.0:
|
||||
raise ValueError(
|
||||
"The image data type is float, which requires values in the range [0.0, 1.0]. "
|
||||
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
|
||||
"provide a uint8 image with values in the range [0, 255]."
|
||||
)
|
||||
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_image(image)
|
||||
img = image_array_to_pil_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
|
||||
@@ -39,6 +39,7 @@ from lerobot.common.datasets.utils import (
|
||||
TASKS_PATH,
|
||||
append_jsonlines,
|
||||
check_delta_timestamps,
|
||||
check_frame_features,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_branch,
|
||||
@@ -724,10 +725,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
then needs to be called.
|
||||
"""
|
||||
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
|
||||
# check the dtype and shape matches, etc.
|
||||
if "task" not in frame:
|
||||
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
|
||||
# Convert torch to numpy if needed
|
||||
for name in frame:
|
||||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
check_frame_features(frame, self.features)
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
@@ -757,8 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._save_image(frame[key], img_path)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
|
||||
self.episode_buffer[key].append(item)
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
@@ -815,12 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
||||
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
|
||||
@@ -35,6 +35,7 @@ from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
@@ -203,7 +204,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
@@ -285,11 +286,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
elif len(ft["shape"]) == 1:
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 4:
|
||||
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 5:
|
||||
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
else:
|
||||
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
@@ -606,3 +616,90 @@ class IterableNamespace(SimpleNamespace):
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
|
||||
|
||||
def check_frame_features(frame: dict, features: dict):
|
||||
optional_features = {"timestamp"}
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
|
||||
error_message = check_features_presence(actual_features, expected_features, optional_features)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += check_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
for name in common_features - {"task"}:
|
||||
error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def check_features_presence(
|
||||
actual_features: set[str], expected_features: set[str], optional_features: set[str]
|
||||
):
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - (expected_features | optional_features)
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
if missing_features:
|
||||
error_message += f"Missing features: {missing_features}\n"
|
||||
if extra_features:
|
||||
error_message += f"Extra features: {extra_features}\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return check_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return check_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
actual_shape = value.shape
|
||||
|
||||
if actual_dtype != np.dtype(expected_dtype):
|
||||
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||
|
||||
if actual_shape != expected_shape:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def check_feature_string(name: str, value: str):
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
@@ -21,6 +21,7 @@ from copy import copy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -200,5 +201,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
|
||||
return shape
|
||||
|
||||
|
||||
def has_method(cls: object, method_name: str):
|
||||
def has_method(cls: object, method_name: str) -> bool:
|
||||
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
|
||||
|
||||
|
||||
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
"""
|
||||
Return True if a given string can be converted to a numpy dtype.
|
||||
"""
|
||||
try:
|
||||
# Attempt to convert the string to a numpy dtype
|
||||
np.dtype(dtype_str)
|
||||
return True
|
||||
except TypeError:
|
||||
# If a TypeError is raised, the string is not a valid dtype
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user