chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -23,6 +23,7 @@ from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
try:
@@ -85,14 +86,14 @@ def test_sac_policy_with_default_args():
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.state": torch.randn(batch_size, state_dim),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.image": torch.randn(batch_size, 3, 84, 84),
"observation.state": torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
}
@@ -126,14 +127,14 @@ def create_train_batch_with_visual_input(
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
"observation.image": torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
}
@@ -180,10 +181,10 @@ def create_default_config(
action_dim += 1
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
"observation.state": {
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
@@ -205,8 +206,8 @@ def create_config_with_visual_input(
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats["observation.image"] = {
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}