fix(async): Add pre and post processing to async inference and update docs (#2132)
* Add pre and post processing to async inference and update docs * precommit fix typo * fix tests * refactor(async): no None branching for processors in _predict_action_chunk --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -333,9 +333,8 @@ def test_raw_observation_to_observation_basic():
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert OBS_STATE in observation
|
||||
@@ -345,7 +344,6 @@ def test_raw_observation_to_observation_basic():
|
||||
# Check state processing
|
||||
state = observation[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.device.type == device
|
||||
assert state.shape == (1, 4) # Batched
|
||||
|
||||
# Check image processing
|
||||
@@ -356,10 +354,6 @@ def test_raw_observation_to_observation_basic():
|
||||
assert laptop_img.shape == (1, 3, 224, 224)
|
||||
assert phone_img.shape == (1, 3, 160, 160)
|
||||
|
||||
# Check device placement
|
||||
assert laptop_img.device.type == device
|
||||
assert phone_img.device.type == device
|
||||
|
||||
# Check image dtype and range (should be float32 in [0, 1])
|
||||
assert laptop_img.dtype == torch.float32
|
||||
assert phone_img.dtype == torch.float32
|
||||
@@ -374,9 +368,8 @@ def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that task string is preserved
|
||||
assert "task" in observation
|
||||
@@ -386,19 +379,17 @@ def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
|
||||
@torch.no_grad()
|
||||
def test_raw_observation_to_observation_device_handling():
|
||||
"""Test that tensors are properly moved to the specified device."""
|
||||
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
"""Test that tensors are created (device placement is handled by preprocessor)."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all tensors are on the correct device
|
||||
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
|
||||
for key, value in observation.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.device.type == device, f"Tensor {key} not on {device}"
|
||||
assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_deterministic():
|
||||
@@ -406,11 +397,10 @@ def test_raw_observation_to_observation_deterministic():
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
# Run twice with same input
|
||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Results should be identical
|
||||
assert set(obs1.keys()) == set(obs2.keys())
|
||||
@@ -448,7 +438,7 @@ def test_image_processing_pipeline_preserves_content():
|
||||
)
|
||||
}
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
|
||||
|
||||
|
||||
Reference in New Issue
Block a user