chore: replace hard-coded OBS values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -19,6 +19,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE
from tests.utils import require_package
@@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params():
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
@@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
@@ -83,7 +84,7 @@ def test_multiclass_classifier():
num_classes = 5
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
@@ -95,7 +96,7 @@ def test_multiclass_classifier():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)),
}

View File

@@ -41,7 +41,7 @@ from lerobot.policies.factory import (
make_pre_post_processors,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -52,7 +52,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
# Create only one camera input which is squared to fit all current policy constraints
# e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
camera_features = {
"observation.images.laptop": {
f"{OBS_IMAGES}.laptop": {
"shape": (84, 84, 3),
"names": ["height", "width", "channels"],
"info": None,
@@ -64,7 +64,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"observation.state": {
OBS_STATE: {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
@@ -281,7 +281,7 @@ def test_multikey_construction(multikey: bool):
preventing erroneous creation of the policy object.
"""
input_features = {
"observation.state": PolicyFeature(
OBS_STATE: PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
@@ -297,9 +297,9 @@ def test_multikey_construction(multikey: bool):
"""Simulates the complete state/action is constructed from more granular multiple
keys, of the same type as the overall state/action"""
input_features = {}
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
output_features = {}
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))

View File

@@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
@@ -37,11 +38,11 @@ def test_sac_config_default_initialization():
"ACTION": NormalizationMode.MIN_MAX,
}
assert config.dataset_stats == {
"observation.image": {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
OBS_STATE: {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
@@ -90,11 +91,11 @@ def test_sac_config_default_initialization():
# Dataset stats defaults
expected_dataset_stats = {
"observation.image": {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
OBS_STATE: {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
@@ -191,7 +192,7 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
@@ -210,7 +211,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):

View File

@@ -23,6 +23,7 @@ from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
try:
@@ -85,14 +86,14 @@ def test_sac_policy_with_default_args():
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.state": torch.randn(batch_size, state_dim),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.image": torch.randn(batch_size, 3, 84, 84),
"observation.state": torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
}
@@ -126,14 +127,14 @@ def create_train_batch_with_visual_input(
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
"observation.image": torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
}
@@ -180,10 +181,10 @@ def create_default_config(
action_dim += 1
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
"observation.state": {
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
@@ -205,8 +206,8 @@ def create_config_with_visual_input(
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats["observation.image"] = {
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}