chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -11,7 +11,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -504,14 +504,14 @@ def test_features_basic():
|
||||
|
||||
input_features = {
|
||||
PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
PipelineFeatureType.ACTION: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Check that original features are preserved
|
||||
assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert "action" in output_features[PipelineFeatureType.ACTION]
|
||||
assert ACTION in output_features[PipelineFeatureType.ACTION]
|
||||
|
||||
# Check that tokenized features are added
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
Reference in New Issue
Block a user