chore: replace hard-coded next values with constants throughout all the source code (#2056)

This commit is contained in:
Steven Palma
2025-09-26 14:30:07 +02:00
committed by GitHub
parent ec40ccde0d
commit c5b5955c5a
13 changed files with 87 additions and 86 deletions

View File

@@ -19,7 +19,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import require_package
@@ -45,7 +45,7 @@ def test_binary_classifier_with_default_params():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
@@ -58,7 +58,7 @@ def test_binary_classifier_with_default_params():
input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
images, labels = classifier.extract_images_and_labels(input)
@@ -87,7 +87,7 @@ def test_multiclass_classifier():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
}
config.num_cameras = 1
config.num_classes = num_classes
@@ -97,7 +97,7 @@ def test_multiclass_classifier():
input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)),
REWARD: torch.rand((batch_size, num_classes)),
}
images, labels = classifier.extract_images_and_labels(input)