chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -39,8 +39,8 @@ def test_process_single_image():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
processed_img = processed_obs["observation.image"]
assert OBS_IMAGE in processed_obs
processed_img = processed_obs[OBS_IMAGE]
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
assert processed_img.shape == (1, 3, 64, 64)
@@ -66,12 +66,12 @@ def test_process_image_dict():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
assert "observation.images.camera2" in processed_obs
assert f"{OBS_IMAGES}.camera1" in processed_obs
assert f"{OBS_IMAGES}.camera2" in processed_obs
# Check shapes
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32)
assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48)
def test_process_batched_image():
@@ -88,7 +88,7 @@ def test_process_batched_image():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64)
def test_invalid_image_format():
@@ -173,10 +173,10 @@ def test_process_environment_state():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert OBS_ENV_STATE in processed_obs
assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"]
processed_state = processed_obs[OBS_ENV_STATE]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
@@ -194,10 +194,10 @@ def test_process_agent_pos():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert OBS_STATE in processed_obs
assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"]
processed_state = processed_obs[OBS_STATE]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
@@ -217,8 +217,8 @@ def test_process_batched_states():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
assert processed_obs[OBS_ENV_STATE].shape == (2, 2)
assert processed_obs[OBS_STATE].shape == (2, 2)
def test_process_both_states():
@@ -235,8 +235,8 @@ def test_process_both_states():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
assert OBS_ENV_STATE in processed_obs
assert OBS_STATE in processed_obs
# Check that original keys were removed
assert "environment_state" not in processed_obs
@@ -281,12 +281,12 @@ def test_complete_observation_processing():
processed_obs = result[TransitionKey.OBSERVATION]
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
assert OBS_IMAGE in processed_obs
assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32)
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
assert OBS_ENV_STATE in processed_obs
assert OBS_STATE in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
@@ -308,7 +308,7 @@ def test_image_only_processing():
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.image" in processed_obs
assert OBS_IMAGE in processed_obs
assert len(processed_obs) == 1
@@ -323,7 +323,7 @@ def test_state_only_processing():
result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert OBS_STATE in processed_obs
assert "agent_pos" not in processed_obs
@@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
proc = VanillaObservationProcessorStep()
features = {
PipelineFeatureType.OBSERVATION: {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
},
}
@@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
assert (
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
== features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
)
assert (
OBS_STATE in out[PipelineFeatureType.OBSERVATION]