chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -39,8 +39,8 @@ def test_process_single_image():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that the image was processed correctly
|
||||
assert "observation.image" in processed_obs
|
||||
processed_img = processed_obs["observation.image"]
|
||||
assert OBS_IMAGE in processed_obs
|
||||
processed_img = processed_obs[OBS_IMAGE]
|
||||
|
||||
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
||||
assert processed_img.shape == (1, 3, 64, 64)
|
||||
@@ -66,12 +66,12 @@ def test_process_image_dict():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both images were processed
|
||||
assert "observation.images.camera1" in processed_obs
|
||||
assert "observation.images.camera2" in processed_obs
|
||||
assert f"{OBS_IMAGES}.camera1" in processed_obs
|
||||
assert f"{OBS_IMAGES}.camera2" in processed_obs
|
||||
|
||||
# Check shapes
|
||||
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
|
||||
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
|
||||
assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32)
|
||||
assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48)
|
||||
|
||||
|
||||
def test_process_batched_image():
|
||||
@@ -88,7 +88,7 @@ def test_process_batched_image():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
|
||||
assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64)
|
||||
|
||||
|
||||
def test_invalid_image_format():
|
||||
@@ -173,10 +173,10 @@ def test_process_environment_state():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that environment_state was renamed and processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert OBS_ENV_STATE in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.environment_state"]
|
||||
processed_state = processed_obs[OBS_ENV_STATE]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
@@ -194,10 +194,10 @@ def test_process_agent_pos():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that agent_pos was renamed and processed
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.state"]
|
||||
processed_state = processed_obs[OBS_STATE]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
|
||||
@@ -217,8 +217,8 @@ def test_process_batched_states():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimensions are preserved
|
||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
||||
assert processed_obs["observation.state"].shape == (2, 2)
|
||||
assert processed_obs[OBS_ENV_STATE].shape == (2, 2)
|
||||
assert processed_obs[OBS_STATE].shape == (2, 2)
|
||||
|
||||
|
||||
def test_process_both_states():
|
||||
@@ -235,8 +235,8 @@ def test_process_both_states():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_ENV_STATE in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
|
||||
# Check that original keys were removed
|
||||
assert "environment_state" not in processed_obs
|
||||
@@ -281,12 +281,12 @@ def test_complete_observation_processing():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that image was processed
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32)
|
||||
|
||||
# Check that states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_ENV_STATE in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
|
||||
# Check that original keys were removed
|
||||
assert "pixels" not in processed_obs
|
||||
@@ -308,7 +308,7 @@ def test_image_only_processing():
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.image" in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert len(processed_obs) == 1
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ def test_state_only_processing():
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
|
||||
@@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
|
||||
proc = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
},
|
||||
}
|
||||
@@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
|
||||
assert (
|
||||
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
|
||||
== features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
)
|
||||
assert (
|
||||
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
Reference in New Issue
Block a user