diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index c5b2fa09..f5fa727c 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.hub import HubMixin @@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def robot_state_feature(self) -> PolicyFeature | None: - for _, ft in self.input_features.items(): - if ft.type is FeatureType.STATE: + for ft_name, ft in self.input_features.items(): + if ft.type is FeatureType.STATE and ft_name == OBS_STATE: return ft return None @@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def action_feature(self) -> PolicyFeature | None: - for _, ft in self.output_features.items(): - if ft.type is FeatureType.ACTION: + for ft_name, ft in self.output_features.items(): + if ft.type is FeatureType.ACTION and ft_name == ACTION: return ft return None diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ed37fedd..da7573d7 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -27,11 +27,13 @@ from lerobot import available_policies from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.utils import preprocess_observation from lerobot.optim.factory import make_optimizer_and_scheduler +from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.policies.factory import ( get_policy_class, @@ -363,6 +365,54 @@ def test_normalize(insert_temporal_dim): unnormalize(output_batch) +@pytest.mark.parametrize("multikey", [True, False]) +def test_multikey_construction(multikey: bool): + """ + Asserts that multiple keys with type State/Action are correctly processed by the policy constructor, + preventing erroneous creation of the policy object. + """ + input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(10,), + ), + } + output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(5,), + ), + } + + if multikey: + """Simulates the complete state/action is constructed from more granular multiple + keys, of the same type as the overall state/action""" + input_features = {} + input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) + + output_features = {} + output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) + output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,)) + output_features["action"] = PolicyFeature( + type=FeatureType.ACTION, + shape=(5,), + ) + + config = ACTConfig(input_features=input_features, output_features=output_features) + + state_condition = config.robot_state_feature == input_features[OBS_STATE] + action_condition = config.action_feature == output_features[ACTION] + + assert state_condition, ( + f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}" + ) + assert action_condition, ( + f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}" + ) + + @pytest.mark.parametrize( "ds_repo_id, policy_name, policy_kwargs, file_name_extra", [