chore: replace hard-coded next values with constants throughout all the source code (#2056)
This commit is contained in:
@@ -46,7 +46,7 @@ from lerobot.datasets.utils import (
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
@@ -399,8 +399,8 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
("timestamp", 0, True),
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
(OBS_STATE, 1, True),
|
||||
("next.reward", 0, False),
|
||||
("next.done", 0, False),
|
||||
(REWARD, 0, False),
|
||||
(DONE, 0, False),
|
||||
]
|
||||
|
||||
# test number of dimensions
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def test_binary_classifier_with_default_params():
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
@@ -58,7 +58,7 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
input = {
|
||||
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
@@ -87,7 +87,7 @@ def test_multiclass_classifier():
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
}
|
||||
config.num_cameras = 1
|
||||
config.num_classes = num_classes
|
||||
@@ -97,7 +97,7 @@ def test_multiclass_classifier():
|
||||
|
||||
input = {
|
||||
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.rand((batch_size, num_classes)),
|
||||
REWARD: torch.rand((batch_size, num_classes)),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
from lerobot.processor.converters import batch_to_transition, transition_to_batch
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, REWARD, TRUNCATED
|
||||
|
||||
|
||||
def _dummy_batch():
|
||||
@@ -12,9 +12,9 @@ def _dummy_batch():
|
||||
f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||
ACTION: torch.tensor([[0.5]]),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
REWARD: 1.0,
|
||||
DONE: False,
|
||||
TRUNCATED: False,
|
||||
"info": {"key": "value"},
|
||||
}
|
||||
|
||||
@@ -38,9 +38,9 @@ def test_observation_grouping_roundtrip():
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(batch_out[ACTION], batch_in[ACTION])
|
||||
assert batch_out["next.reward"] == batch_in["next.reward"]
|
||||
assert batch_out["next.done"] == batch_in["next.done"]
|
||||
assert batch_out["next.truncated"] == batch_in["next.truncated"]
|
||||
assert batch_out[REWARD] == batch_in[REWARD]
|
||||
assert batch_out[DONE] == batch_in[DONE]
|
||||
assert batch_out[TRUNCATED] == batch_in[TRUNCATED]
|
||||
assert batch_out["info"] == batch_in["info"]
|
||||
|
||||
|
||||
@@ -51,9 +51,9 @@ def test_batch_to_transition_observation_grouping():
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: [1, 2, 3, 4],
|
||||
ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
REWARD: 1.5,
|
||||
DONE: True,
|
||||
TRUNCATED: False,
|
||||
"info": {"episode": 42},
|
||||
}
|
||||
|
||||
@@ -115,9 +115,9 @@ def test_transition_to_batch_observation_flattening():
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert batch[ACTION] == "action_data"
|
||||
assert batch["next.reward"] == 1.5
|
||||
assert batch["next.done"]
|
||||
assert not batch["next.truncated"]
|
||||
assert batch[REWARD] == 1.5
|
||||
assert batch[DONE]
|
||||
assert not batch[TRUNCATED]
|
||||
assert batch["info"] == {"episode": 42}
|
||||
|
||||
|
||||
@@ -125,9 +125,9 @@ def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
ACTION: torch.tensor([1.0, 2.0]),
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
REWARD: 2.0,
|
||||
DONE: False,
|
||||
TRUNCATED: True,
|
||||
"info": {"test": "no_obs"},
|
||||
}
|
||||
|
||||
@@ -146,9 +146,9 @@ def test_no_observation_keys():
|
||||
# Round trip should work
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0]))
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
assert reconstructed_batch[REWARD] == 2.0
|
||||
assert not reconstructed_batch[DONE]
|
||||
assert reconstructed_batch[TRUNCATED]
|
||||
assert reconstructed_batch["info"] == {"test": "no_obs"}
|
||||
|
||||
|
||||
@@ -173,9 +173,9 @@ def test_minimal_batch():
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch[OBS_STATE] == "minimal_state"
|
||||
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
assert reconstructed_batch[REWARD] == 0.0
|
||||
assert not reconstructed_batch[DONE]
|
||||
assert not reconstructed_batch[TRUNCATED]
|
||||
assert reconstructed_batch["info"] == {}
|
||||
|
||||
|
||||
@@ -197,9 +197,9 @@ def test_empty_batch():
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch[ACTION] is None
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
assert reconstructed_batch[REWARD] == 0.0
|
||||
assert not reconstructed_batch[DONE]
|
||||
assert not reconstructed_batch[TRUNCATED]
|
||||
assert reconstructed_batch["info"] == {}
|
||||
|
||||
|
||||
@@ -210,9 +210,9 @@ def test_complex_nested_observation():
|
||||
f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
OBS_STATE: torch.randn(7),
|
||||
ACTION: torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
REWARD: 3.14,
|
||||
DONE: False,
|
||||
TRUNCATED: True,
|
||||
"info": {"episode_length": 200, "success": True},
|
||||
}
|
||||
|
||||
@@ -240,9 +240,9 @@ def test_complex_nested_observation():
|
||||
assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION])
|
||||
|
||||
# Check other fields
|
||||
assert batch["next.reward"] == reconstructed_batch["next.reward"]
|
||||
assert batch["next.done"] == reconstructed_batch["next.done"]
|
||||
assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
|
||||
assert batch[REWARD] == reconstructed_batch[REWARD]
|
||||
assert batch[DONE] == reconstructed_batch[DONE]
|
||||
assert batch[TRUNCATED] == reconstructed_batch[TRUNCATED]
|
||||
assert batch["info"] == reconstructed_batch["info"]
|
||||
|
||||
|
||||
@@ -267,13 +267,13 @@ def test_custom_converter():
|
||||
batch = {
|
||||
OBS_STATE: torch.randn(1, 4),
|
||||
ACTION: torch.randn(1, 2),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
REWARD: 1.0,
|
||||
DONE: False,
|
||||
}
|
||||
|
||||
result = processor(batch)
|
||||
|
||||
# Check the reward was doubled by our custom converter
|
||||
assert result["next.reward"] == 2.0
|
||||
assert result[REWARD] == 2.0
|
||||
assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
|
||||
assert torch.allclose(result[ACTION], batch[ACTION])
|
||||
|
||||
@@ -9,7 +9,7 @@ from lerobot.processor.converters import (
|
||||
to_tensor,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, OBS_STR, REWARD
|
||||
|
||||
|
||||
# Tests for the unified to_tensor function
|
||||
@@ -201,8 +201,8 @@ def test_batch_to_transition_with_index_fields():
|
||||
batch = {
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
ACTION: torch.randn(1, 4),
|
||||
"next.reward": 1.5,
|
||||
"next.done": False,
|
||||
REWARD: 1.5,
|
||||
DONE: False,
|
||||
"task": ["pick_cube"],
|
||||
"index": torch.tensor([42], dtype=torch.int64),
|
||||
"task_index": torch.tensor([3], dtype=torch.int64),
|
||||
|
||||
@@ -35,7 +35,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -258,9 +258,9 @@ def test_step_through_with_dict():
|
||||
batch = {
|
||||
OBS_IMAGE: None,
|
||||
ACTION: None,
|
||||
"next.reward": 0.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
REWARD: 0.0,
|
||||
DONE: False,
|
||||
TRUNCATED: False,
|
||||
"info": {},
|
||||
}
|
||||
|
||||
@@ -1843,9 +1843,9 @@ def test_save_load_with_custom_converter_functions():
|
||||
batch = {
|
||||
OBS_IMAGE: torch.randn(1, 3, 32, 32),
|
||||
ACTION: torch.randn(1, 7),
|
||||
"next.reward": torch.tensor([1.0]),
|
||||
"next.done": torch.tensor([False]),
|
||||
"next.truncated": torch.tensor([False]),
|
||||
REWARD: torch.tensor([1.0]),
|
||||
DONE: torch.tensor([False]),
|
||||
TRUNCATED: torch.tensor([False]),
|
||||
"info": {},
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
@@ -380,9 +380,9 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
for feature, value in ds[i].items():
|
||||
if feature == ACTION:
|
||||
assert torch.equal(value, buffer.actions[i])
|
||||
elif feature == "next.reward":
|
||||
elif feature == REWARD:
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == "next.done":
|
||||
elif feature == DONE:
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
elif feature == OBS_IMAGE:
|
||||
# Tensor -> numpy is not precise, so we have some diff there
|
||||
|
||||
Reference in New Issue
Block a user