chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -41,7 +41,7 @@ from lerobot.policies.factory import (
|
||||
make_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
@@ -52,7 +52,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
# Create only one camera input which is squared to fit all current policy constraints
|
||||
# e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
|
||||
camera_features = {
|
||||
"observation.images.laptop": {
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"shape": (84, 84, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
@@ -64,7 +64,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
},
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
@@ -281,7 +281,7 @@ def test_multikey_construction(multikey: bool):
|
||||
preventing erroneous creation of the policy object.
|
||||
"""
|
||||
input_features = {
|
||||
"observation.state": PolicyFeature(
|
||||
OBS_STATE: PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
@@ -297,9 +297,9 @@ def test_multikey_construction(multikey: bool):
|
||||
"""Simulates the complete state/action is constructed from more granular multiple
|
||||
keys, of the same type as the overall state/action"""
|
||||
input_features = {}
|
||||
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||
input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||
input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
|
||||
output_features = {}
|
||||
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
|
||||
|
||||
Reference in New Issue
Block a user