Fix policy construction (#1665)

* add: test to check proper construction with multiple features with STATE/ACTION type

* fix: robot and action state should match policy's expectations

* fix minor

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Francesco Capuano
2025-08-04 21:49:51 +02:00
committed by GitHub
parent 8c577525c1
commit 90d3a99aa1
2 changed files with 55 additions and 4 deletions

View File

@@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin from lerobot.utils.hub import HubMixin
@@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property @property
def robot_state_feature(self) -> PolicyFeature | None: def robot_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items(): for ft_name, ft in self.input_features.items():
if ft.type is FeatureType.STATE: if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
return ft return ft
return None return None
@@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property @property
def action_feature(self) -> PolicyFeature | None: def action_feature(self) -> PolicyFeature | None:
for _, ft in self.output_features.items(): for ft_name, ft in self.output_features.items():
if ft.type is FeatureType.ACTION: if ft.type is FeatureType.ACTION and ft_name == ACTION:
return ft return ft
return None return None

View File

@@ -27,11 +27,13 @@ from lerobot import available_policies
from lerobot.configs.default import DatasetConfig from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.datasets.factory import make_dataset from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.datasets.utils import cycle, dataset_to_policy_features
from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation from lerobot.envs.utils import preprocess_observation
from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.policies.factory import ( from lerobot.policies.factory import (
get_policy_class, get_policy_class,
@@ -363,6 +365,54 @@ def test_normalize(insert_temporal_dim):
unnormalize(output_batch) unnormalize(output_batch)
@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
"""
Asserts that multiple keys with type State/Action are correctly processed by the policy constructor,
preventing erroneous creation of the policy object.
"""
input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
}
output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
}
if multikey:
"""Simulates the complete state/action is constructed from more granular multiple
keys, of the same type as the overall state/action"""
input_features = {}
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
output_features = {}
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
output_features["action"] = PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
)
config = ACTConfig(input_features=input_features, output_features=output_features)
state_condition = config.robot_state_feature == input_features[OBS_STATE]
action_condition = config.action_feature == output_features[ACTION]
assert state_condition, (
f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}"
)
assert action_condition, (
f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ds_repo_id, policy_name, policy_kwargs, file_name_extra", "ds_repo_id, policy_name, policy_kwargs, file_name_extra",
[ [