forked from tangger/lerobot
Fix test_datasets
This commit is contained in:
@@ -48,7 +48,6 @@ from lerobot.common.datasets.utils import (
|
|||||||
embed_images,
|
embed_images,
|
||||||
get_delta_indices,
|
get_delta_indices,
|
||||||
get_episode_data_index,
|
get_episode_data_index,
|
||||||
get_features_from_robot,
|
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
get_safe_version,
|
get_safe_version,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
@@ -72,7 +71,6 @@ from lerobot.common.datasets.video_utils import (
|
|||||||
get_safe_default_codec,
|
get_safe_default_codec,
|
||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
from lerobot.common.robots import Robot
|
|
||||||
|
|
||||||
CODEBASE_VERSION = "v2.1"
|
CODEBASE_VERSION = "v2.1"
|
||||||
|
|
||||||
@@ -304,10 +302,9 @@ class LeRobotDatasetMetadata:
|
|||||||
cls,
|
cls,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
fps: int,
|
fps: int,
|
||||||
root: str | Path | None = None,
|
features: dict,
|
||||||
robot: Robot | None = None,
|
|
||||||
robot_type: str | None = None,
|
robot_type: str | None = None,
|
||||||
features: dict | None = None,
|
root: str | Path | None = None,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
) -> "LeRobotDatasetMetadata":
|
) -> "LeRobotDatasetMetadata":
|
||||||
"""Creates metadata for a LeRobotDataset."""
|
"""Creates metadata for a LeRobotDataset."""
|
||||||
@@ -317,33 +314,27 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
obj.root.mkdir(parents=True, exist_ok=False)
|
obj.root.mkdir(parents=True, exist_ok=False)
|
||||||
|
|
||||||
if robot is not None:
|
# if robot is not None:
|
||||||
features = get_features_from_robot(robot, use_videos)
|
# features = get_features_from_robot(robot, use_videos)
|
||||||
robot_type = robot.robot_type
|
# robot_type = robot.robot_type
|
||||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
# if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||||
logging.warning(
|
# logging.warning(
|
||||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
# f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
# "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
||||||
)
|
# )
|
||||||
elif features is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# TODO(aliberts, rcadene): implement sanity check for features
|
|
||||||
features = {**features, **DEFAULT_FEATURES}
|
|
||||||
|
|
||||||
# check if none of the features contains a "/" in their names,
|
# TODO(aliberts, rcadene): implement sanity check for features
|
||||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
features = {**features, **DEFAULT_FEATURES}
|
||||||
for key in features:
|
|
||||||
if "/" in key:
|
|
||||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
|
||||||
|
|
||||||
features = {**features, **DEFAULT_FEATURES}
|
# check if none of the features contains a "/" in their names,
|
||||||
|
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||||
|
for key in features:
|
||||||
|
if "/" in key:
|
||||||
|
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||||
|
|
||||||
obj.tasks, obj.task_to_task_index = {}, {}
|
obj.tasks, obj.task_to_task_index = {}, {}
|
||||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
|
||||||
if len(obj.video_keys) > 0 and not use_videos:
|
if len(obj.video_keys) > 0 and not use_videos:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
write_json(obj.info, obj.root / INFO_PATH)
|
write_json(obj.info, obj.root / INFO_PATH)
|
||||||
@@ -986,10 +977,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
cls,
|
cls,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
fps: int,
|
fps: int,
|
||||||
|
features: dict,
|
||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
robot: Robot | None = None,
|
|
||||||
robot_type: str | None = None,
|
robot_type: str | None = None,
|
||||||
features: dict | None = None,
|
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
image_writer_processes: int = 0,
|
image_writer_processes: int = 0,
|
||||||
@@ -1001,10 +991,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.meta = LeRobotDatasetMetadata.create(
|
obj.meta = LeRobotDatasetMetadata.create(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
root=root,
|
|
||||||
robot=robot,
|
|
||||||
robot_type=robot_type,
|
robot_type=robot_type,
|
||||||
features=features,
|
features=features,
|
||||||
|
root=root,
|
||||||
use_videos=use_videos,
|
use_videos=use_videos,
|
||||||
)
|
)
|
||||||
obj.repo_id = obj.meta.repo_id
|
obj.repo_id = obj.meta.repo_id
|
||||||
|
|||||||
@@ -477,9 +477,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
def create_empty_dataset_info(
|
def create_empty_dataset_info(
|
||||||
codebase_version: str,
|
codebase_version: str,
|
||||||
fps: int,
|
fps: int,
|
||||||
robot_type: str,
|
|
||||||
features: dict,
|
features: dict,
|
||||||
use_videos: bool,
|
use_videos: bool,
|
||||||
|
robot_type: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
return {
|
return {
|
||||||
"codebase_version": codebase_version,
|
"codebase_version": codebase_version,
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
|||||||
return KochFollowerConfig(**kwargs)
|
return KochFollowerConfig(**kwargs)
|
||||||
# elif robot_type == "koch_bimanual":
|
# elif robot_type == "koch_bimanual":
|
||||||
# return KochBimanualRobotConfig(**kwargs)
|
# return KochBimanualRobotConfig(**kwargs)
|
||||||
elif robot_type == "moss":
|
elif robot_type == "moss_follower":
|
||||||
from .moss_follower.configuration_moss import MossRobotConfig
|
from .moss_follower.configuration_moss import MossRobotConfig
|
||||||
|
|
||||||
return MossRobotConfig(**kwargs)
|
return MossRobotConfig(**kwargs)
|
||||||
elif robot_type == "so100_leader":
|
elif robot_type == "so100_follower":
|
||||||
from .so100_follower.config_so100_follower import SO100FollowerConfig
|
from .so100_follower.config_so100_follower import SO100FollowerConfig
|
||||||
|
|
||||||
return SO100FollowerConfig(**kwargs)
|
return SO100FollowerConfig(**kwargs)
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from lerobot.common.datasets.utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.envs.factory import make_env_config
|
from lerobot.common.envs.factory import make_env_config
|
||||||
from lerobot.common.policies.factory import make_policy_config
|
from lerobot.common.policies.factory import make_policy_config
|
||||||
from lerobot.common.robots.utils import make_robot
|
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
@@ -70,9 +69,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
|||||||
objects have the same sets of attributes defined.
|
objects have the same sets of attributes defined.
|
||||||
"""
|
"""
|
||||||
# Instantiate both ways
|
# Instantiate both ways
|
||||||
robot = make_robot("koch", mock=True)
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
root_create = tmp_path / "create"
|
root_create = tmp_path / "create"
|
||||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
|
||||||
|
|
||||||
root_init = tmp_path / "init"
|
root_init = tmp_path / "init"
|
||||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||||
@@ -100,22 +99,13 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
|||||||
assert dataset.num_frames == len(dataset)
|
assert dataset.num_frames == len(dataset)
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": torch.randn(1)})
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"task": "Dummy task"})
|
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -124,7 +114,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
|
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -133,7 +123,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -143,7 +133,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
ValueError,
|
ValueError,
|
||||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -155,7 +145,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact
|
|||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
|
dataset.add_frame({"state": 1.0}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -165,7 +155,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
|||||||
ValueError,
|
ValueError,
|
||||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -177,13 +167,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
|||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
|
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
@@ -195,7 +185,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2])
|
assert dataset[0]["state"].shape == torch.Size([2])
|
||||||
@@ -204,7 +194,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
||||||
@@ -213,7 +203,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
||||||
@@ -222,7 +212,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
||||||
@@ -231,7 +221,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
||||||
@@ -240,7 +230,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
|
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].ndim == 0
|
assert dataset[0]["state"].ndim == 0
|
||||||
@@ -249,7 +239,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
|
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["caption"] == "Dummy caption"
|
assert dataset[0]["caption"] == "Dummy caption"
|
||||||
@@ -264,7 +254,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
c, h, w = DUMMY_CHW
|
c, h, w = DUMMY_CHW
|
||||||
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
|
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_wrong_range(image_dataset):
|
def test_add_frame_image_wrong_range(image_dataset):
|
||||||
@@ -277,14 +267,14 @@ def test_add_frame_image_wrong_range(image_dataset):
|
|||||||
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
||||||
"""
|
"""
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image(image_dataset):
|
def test_add_frame_image(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -292,7 +282,7 @@ def test_add_frame_image(image_dataset):
|
|||||||
|
|
||||||
def test_add_frame_image_h_w_c(image_dataset):
|
def test_add_frame_image_h_w_c(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -301,7 +291,7 @@ def test_add_frame_image_h_w_c(image_dataset):
|
|||||||
def test_add_frame_image_uint8(image_dataset):
|
def test_add_frame_image_uint8(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||||
dataset.add_frame({"image": image, "task": "Dummy task"})
|
dataset.add_frame({"image": image}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -310,7 +300,7 @@ def test_add_frame_image_uint8(image_dataset):
|
|||||||
def test_add_frame_image_pil(image_dataset):
|
def test_add_frame_image_pil(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||||
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
|
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
|
|||||||
Reference in New Issue
Block a user