chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -23,6 +23,7 @@ from torch import Tensor, nn
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
try:
|
||||
@@ -85,14 +86,14 @@ def test_sac_policy_with_default_args():
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
@@ -126,14 +127,14 @@ def create_train_batch_with_visual_input(
|
||||
|
||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
|
||||
}
|
||||
|
||||
|
||||
@@ -180,10 +181,10 @@ def create_default_config(
|
||||
action_dim += 1
|
||||
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
@@ -205,8 +206,8 @@ def create_config_with_visual_input(
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
has_discrete_action=has_discrete_action,
|
||||
)
|
||||
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats["observation.image"] = {
|
||||
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats[OBS_IMAGE] = {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user