forked from tangger/lerobot
Fix policy construction (#1665)
* add: test to check proper construction with multiple features with STATE/ACTION type * fix: robot and action state should match policy's expectations * fix minor Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8c577525c1
commit
90d3a99aa1
@@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
@@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
|
||||
@property
|
||||
def robot_state_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE:
|
||||
for ft_name, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
|
||||
return ft
|
||||
return None
|
||||
|
||||
@@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION:
|
||||
for ft_name, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION and ft_name == ACTION:
|
||||
return ft
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user