chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -19,6 +19,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE
from tests.utils import require_package
@@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params():
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
@@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
@@ -83,7 +84,7 @@ def test_multiclass_classifier():
num_classes = 5
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
@@ -95,7 +96,7 @@ def test_multiclass_classifier():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)),
}