forked from tangger/lerobot
(fix): test
This commit is contained in:
@@ -0,0 +1,207 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
########################################################################################
|
||||||
|
# Utilities
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from copy import copy
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from deepdiff import DeepDiff
|
||||||
|
from termcolor import colored
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.datasets.utils import DEFAULT_FEATURES
|
||||||
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.common.robots import Robot
|
||||||
|
|
||||||
|
|
||||||
|
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||||
|
log_items = []
|
||||||
|
if episode_index is not None:
|
||||||
|
log_items.append(f"ep:{episode_index}")
|
||||||
|
if frame_index is not None:
|
||||||
|
log_items.append(f"frame:{frame_index}")
|
||||||
|
|
||||||
|
def log_dt(shortname, dt_val_s):
|
||||||
|
nonlocal log_items, fps
|
||||||
|
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
|
||||||
|
if fps is not None:
|
||||||
|
actual_fps = 1 / dt_val_s
|
||||||
|
if actual_fps < fps - 1:
|
||||||
|
info_str = colored(info_str, "yellow")
|
||||||
|
log_items.append(info_str)
|
||||||
|
|
||||||
|
# total step time displayed in milliseconds and its frequency
|
||||||
|
log_dt("dt", dt_s)
|
||||||
|
|
||||||
|
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
||||||
|
if not robot.robot_type.startswith("stretch"):
|
||||||
|
for name in robot.leader_arms:
|
||||||
|
key = f"read_leader_{name}_pos_dt_s"
|
||||||
|
if key in robot.logs:
|
||||||
|
log_dt("dtRlead", robot.logs[key])
|
||||||
|
|
||||||
|
for name in robot.follower_arms:
|
||||||
|
key = f"write_follower_{name}_goal_pos_dt_s"
|
||||||
|
if key in robot.logs:
|
||||||
|
log_dt("dtWfoll", robot.logs[key])
|
||||||
|
|
||||||
|
key = f"read_follower_{name}_pos_dt_s"
|
||||||
|
if key in robot.logs:
|
||||||
|
log_dt("dtRfoll", robot.logs[key])
|
||||||
|
|
||||||
|
for name in robot.cameras:
|
||||||
|
key = f"read_camera_{name}_dt_s"
|
||||||
|
if key in robot.logs:
|
||||||
|
log_dt(f"dtR{name}", robot.logs[key])
|
||||||
|
|
||||||
|
info_str = " ".join(log_items)
|
||||||
|
logging.info(info_str)
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def is_headless():
|
||||||
|
"""Detects if python is running without a monitor."""
|
||||||
|
try:
|
||||||
|
import pynput # noqa
|
||||||
|
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
print(
|
||||||
|
"Error trying to import pynput. Switching to headless mode. "
|
||||||
|
"As a result, the video stream from the cameras won't be shown, "
|
||||||
|
"and you won't be able to change the control flow with keyboards. "
|
||||||
|
"For more info, see traceback below.\n"
|
||||||
|
)
|
||||||
|
traceback.print_exc()
|
||||||
|
print()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def predict_action(
|
||||||
|
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
|
||||||
|
):
|
||||||
|
observation = copy(observation)
|
||||||
|
with (
|
||||||
|
torch.inference_mode(),
|
||||||
|
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||||
|
):
|
||||||
|
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||||
|
for name in observation:
|
||||||
|
observation[name] = torch.from_numpy(observation[name])
|
||||||
|
if "image" in name:
|
||||||
|
observation[name] = observation[name].type(torch.float32) / 255
|
||||||
|
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||||
|
observation[name] = observation[name].unsqueeze(0)
|
||||||
|
observation[name] = observation[name].to(device)
|
||||||
|
|
||||||
|
# Compute the next action with the policy
|
||||||
|
# based on the current observation
|
||||||
|
action = policy.select_action(observation)
|
||||||
|
|
||||||
|
# Remove batch dimension
|
||||||
|
action = action.squeeze(0)
|
||||||
|
|
||||||
|
# Move to cpu, if not already the case
|
||||||
|
action = action.to("cpu")
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
def init_keyboard_listener():
|
||||||
|
# Allow to exit early while recording an episode or resetting the environment,
|
||||||
|
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||||
|
# to allow your terminal to monitor keyboard events.
|
||||||
|
events = {}
|
||||||
|
events["exit_early"] = False
|
||||||
|
events["rerecord_episode"] = False
|
||||||
|
events["stop_recording"] = False
|
||||||
|
|
||||||
|
if is_headless():
|
||||||
|
logging.warning(
|
||||||
|
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||||
|
)
|
||||||
|
listener = None
|
||||||
|
return listener, events
|
||||||
|
|
||||||
|
# Only import pynput if not in a headless environment
|
||||||
|
from pynput import keyboard
|
||||||
|
|
||||||
|
def on_press(key):
|
||||||
|
try:
|
||||||
|
if key == keyboard.Key.right:
|
||||||
|
print("Right arrow key pressed. Exiting loop...")
|
||||||
|
events["exit_early"] = True
|
||||||
|
elif key == keyboard.Key.left:
|
||||||
|
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||||
|
events["rerecord_episode"] = True
|
||||||
|
events["exit_early"] = True
|
||||||
|
elif key == keyboard.Key.esc:
|
||||||
|
print("Escape key pressed. Stopping data recording...")
|
||||||
|
events["stop_recording"] = True
|
||||||
|
events["exit_early"] = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error handling key press: {e}")
|
||||||
|
|
||||||
|
listener = keyboard.Listener(on_press=on_press)
|
||||||
|
listener.start()
|
||||||
|
|
||||||
|
return listener, events
|
||||||
|
|
||||||
|
|
||||||
|
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||||
|
_, dataset_name = repo_id.split("/")
|
||||||
|
# either repo_id doesnt start with "eval_" and there is no policy
|
||||||
|
# or repo_id starts with "eval_" and there is a policy
|
||||||
|
|
||||||
|
# Check if dataset_name starts with "eval_" but policy is missing
|
||||||
|
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||||
|
if not dataset_name.startswith("eval_") and policy_cfg is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sanity_check_dataset_robot_compatibility(
|
||||||
|
dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
|
||||||
|
) -> None:
|
||||||
|
fields = [
|
||||||
|
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||||
|
("fps", dataset.fps, fps),
|
||||||
|
("features", dataset.features, {**features, **DEFAULT_FEATURES}),
|
||||||
|
]
|
||||||
|
|
||||||
|
mismatches = []
|
||||||
|
for field, dataset_value, present_value in fields:
|
||||||
|
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
|
||||||
|
if diff:
|
||||||
|
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||||
|
|
||||||
|
if mismatches:
|
||||||
|
raise ValueError(
|
||||||
|
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||||
|
)
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from lerobot.common.datasets.utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.envs.factory import make_env_config
|
from lerobot.common.envs.factory import make_env_config
|
||||||
from lerobot.common.policies.factory import make_policy_config
|
from lerobot.common.policies.factory import make_policy_config
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
@@ -70,9 +69,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
|||||||
objects have the same sets of attributes defined.
|
objects have the same sets of attributes defined.
|
||||||
"""
|
"""
|
||||||
# Instantiate both ways
|
# Instantiate both ways
|
||||||
robot = make_robot("koch", mock=True)
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
root_create = tmp_path / "create"
|
root_create = tmp_path / "create"
|
||||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
|
||||||
|
|
||||||
root_init = tmp_path / "init"
|
root_init = tmp_path / "init"
|
||||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||||
@@ -100,22 +99,13 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
|||||||
assert dataset.num_frames == len(dataset)
|
assert dataset.num_frames == len(dataset)
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
|
|
||||||
):
|
|
||||||
dataset.add_frame({"state": torch.randn(1)})
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"task": "Dummy task"})
|
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -124,7 +114,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
|
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -133,7 +123,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -143,7 +133,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
ValueError,
|
ValueError,
|
||||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -155,7 +145,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact
|
|||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
|
dataset.add_frame({"state": 1.0}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -165,7 +155,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
|||||||
ValueError,
|
ValueError,
|
||||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -177,13 +167,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
|||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
|
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
@@ -195,7 +185,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2])
|
assert dataset[0]["state"].shape == torch.Size([2])
|
||||||
@@ -204,7 +194,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
||||||
@@ -213,7 +203,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
||||||
@@ -222,7 +212,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
||||||
@@ -231,7 +221,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
|
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
||||||
@@ -240,7 +230,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
|
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].ndim == 0
|
assert dataset[0]["state"].ndim == 0
|
||||||
@@ -249,7 +239,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
|
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["caption"] == "Dummy caption"
|
assert dataset[0]["caption"] == "Dummy caption"
|
||||||
@@ -264,7 +254,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
c, h, w = DUMMY_CHW
|
c, h, w = DUMMY_CHW
|
||||||
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
|
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_wrong_range(image_dataset):
|
def test_add_frame_image_wrong_range(image_dataset):
|
||||||
@@ -277,14 +267,14 @@ def test_add_frame_image_wrong_range(image_dataset):
|
|||||||
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
||||||
"""
|
"""
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image(image_dataset):
|
def test_add_frame_image(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -292,7 +282,7 @@ def test_add_frame_image(image_dataset):
|
|||||||
|
|
||||||
def test_add_frame_image_h_w_c(image_dataset):
|
def test_add_frame_image_h_w_c(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -301,7 +291,7 @@ def test_add_frame_image_h_w_c(image_dataset):
|
|||||||
def test_add_frame_image_uint8(image_dataset):
|
def test_add_frame_image_uint8(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||||
dataset.add_frame({"image": image, "task": "Dummy task"})
|
dataset.add_frame({"image": image}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -310,7 +300,7 @@ def test_add_frame_image_uint8(image_dataset):
|
|||||||
def test_add_frame_image_pil(image_dataset):
|
def test_add_frame_image_pil(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||||
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
|
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
|
|||||||
Reference in New Issue
Block a user