chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -46,6 +46,7 @@ from lerobot.datasets.utils import (
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
@@ -75,7 +76,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
# Instantiate both ways
|
||||
robot = make_robot_from_config(MockRobotConfig())
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", True)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(
|
||||
@@ -397,7 +398,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
("frame_index", 0, True),
|
||||
("timestamp", 0, True),
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
("observation.state", 1, True),
|
||||
(OBS_STATE, 1, True),
|
||||
("next.reward", 0, False),
|
||||
("next.done", 0, False),
|
||||
]
|
||||
@@ -662,7 +663,7 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata."""
|
||||
features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
|
||||
@@ -769,7 +770,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_update_chunk_settings_video_dataset(tmp_path):
|
||||
"""Test update_chunk_settings with a video dataset to ensure video-specific logic works."""
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
f"{OBS_IMAGES}.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
|
||||
Reference in New Issue
Block a user