Fix policy construction (#1665)
* add: test to check proper construction with multiple features with STATE/ACTION type * fix: robot and action state should match policy's expectations * fix minor Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8c577525c1
commit
90d3a99aa1
@@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME
|
|||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.constants import ACTION, OBS_STATE
|
||||||
from lerobot.optim.optimizers import OptimizerConfig
|
from lerobot.optim.optimizers import OptimizerConfig
|
||||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.utils.hub import HubMixin
|
from lerobot.utils.hub import HubMixin
|
||||||
@@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def robot_state_feature(self) -> PolicyFeature | None:
|
def robot_state_feature(self) -> PolicyFeature | None:
|
||||||
for _, ft in self.input_features.items():
|
for ft_name, ft in self.input_features.items():
|
||||||
if ft.type is FeatureType.STATE:
|
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
|
||||||
return ft
|
return ft
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_feature(self) -> PolicyFeature | None:
|
def action_feature(self) -> PolicyFeature | None:
|
||||||
for _, ft in self.output_features.items():
|
for ft_name, ft in self.output_features.items():
|
||||||
if ft.type is FeatureType.ACTION:
|
if ft.type is FeatureType.ACTION and ft_name == ACTION:
|
||||||
return ft
|
return ft
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -27,11 +27,13 @@ from lerobot import available_policies
|
|||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.constants import ACTION, OBS_STATE
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||||
from lerobot.envs.factory import make_env, make_env_config
|
from lerobot.envs.factory import make_env, make_env_config
|
||||||
from lerobot.envs.utils import preprocess_observation
|
from lerobot.envs.utils import preprocess_observation
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
||||||
from lerobot.policies.factory import (
|
from lerobot.policies.factory import (
|
||||||
get_policy_class,
|
get_policy_class,
|
||||||
@@ -363,6 +365,54 @@ def test_normalize(insert_temporal_dim):
|
|||||||
unnormalize(output_batch)
|
unnormalize(output_batch)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("multikey", [True, False])
|
||||||
|
def test_multikey_construction(multikey: bool):
|
||||||
|
"""
|
||||||
|
Asserts that multiple keys with type State/Action are correctly processed by the policy constructor,
|
||||||
|
preventing erroneous creation of the policy object.
|
||||||
|
"""
|
||||||
|
input_features = {
|
||||||
|
"observation.state": PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(10,),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
output_features = {
|
||||||
|
"action": PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(5,),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
if multikey:
|
||||||
|
"""Simulates the complete state/action is constructed from more granular multiple
|
||||||
|
keys, of the same type as the overall state/action"""
|
||||||
|
input_features = {}
|
||||||
|
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||||
|
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||||
|
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||||
|
|
||||||
|
output_features = {}
|
||||||
|
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
|
||||||
|
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
|
||||||
|
output_features["action"] = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(5,),
|
||||||
|
)
|
||||||
|
|
||||||
|
config = ACTConfig(input_features=input_features, output_features=output_features)
|
||||||
|
|
||||||
|
state_condition = config.robot_state_feature == input_features[OBS_STATE]
|
||||||
|
action_condition = config.action_feature == output_features[ACTION]
|
||||||
|
|
||||||
|
assert state_condition, (
|
||||||
|
f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}"
|
||||||
|
)
|
||||||
|
assert action_condition, (
|
||||||
|
f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
|
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user