Fix policy construction (#1665)

* add: test to check proper construction with multiple features with STATE/ACTION type

* fix: robot and action state should match policy's expectations

* fix minor

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Francesco Capuano
2025-08-04 21:49:51 +02:00
committed by GitHub
parent 8c577525c1
commit 90d3a99aa1
2 changed files with 55 additions and 4 deletions

View File

@@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
@@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def robot_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items():
if ft.type is FeatureType.STATE:
for ft_name, ft in self.input_features.items():
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
return ft
return None
@@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def action_feature(self) -> PolicyFeature | None:
for _, ft in self.output_features.items():
if ft.type is FeatureType.ACTION:
for ft_name, ft in self.output_features.items():
if ft.type is FeatureType.ACTION and ft_name == ACTION:
return ft
return None