fix(async): Add pre and post processing to async inference and update docs (#2132)
* Add pre and post processing to async inference and update docs * precommit fix typo * fix tests * refactor(async): no None branching for processors in _predict_action_chunk --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -91,6 +91,9 @@ def test_async_inference_e2e(monkeypatch):
|
||||
policy_server.policy = MockPolicy()
|
||||
policy_server.actions_per_chunk = 20
|
||||
policy_server.device = "cpu"
|
||||
# NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
|
||||
policy_server.preprocessor = lambda obs: obs
|
||||
policy_server.postprocessor = lambda tensor: tensor
|
||||
|
||||
# Set up robot config and features
|
||||
robot_config = MockRobotConfig()
|
||||
|
||||
@@ -333,9 +333,8 @@ def test_raw_observation_to_observation_basic():
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert OBS_STATE in observation
|
||||
@@ -345,7 +344,6 @@ def test_raw_observation_to_observation_basic():
|
||||
# Check state processing
|
||||
state = observation[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.device.type == device
|
||||
assert state.shape == (1, 4) # Batched
|
||||
|
||||
# Check image processing
|
||||
@@ -356,10 +354,6 @@ def test_raw_observation_to_observation_basic():
|
||||
assert laptop_img.shape == (1, 3, 224, 224)
|
||||
assert phone_img.shape == (1, 3, 160, 160)
|
||||
|
||||
# Check device placement
|
||||
assert laptop_img.device.type == device
|
||||
assert phone_img.device.type == device
|
||||
|
||||
# Check image dtype and range (should be float32 in [0, 1])
|
||||
assert laptop_img.dtype == torch.float32
|
||||
assert phone_img.dtype == torch.float32
|
||||
@@ -374,9 +368,8 @@ def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that task string is preserved
|
||||
assert "task" in observation
|
||||
@@ -386,19 +379,17 @@ def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
|
||||
@torch.no_grad()
|
||||
def test_raw_observation_to_observation_device_handling():
|
||||
"""Test that tensors are properly moved to the specified device."""
|
||||
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
"""Test that tensors are created (device placement is handled by preprocessor)."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that all tensors are on the correct device
|
||||
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
|
||||
for key, value in observation.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.device.type == device, f"Tensor {key} not on {device}"
|
||||
assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_deterministic():
|
||||
@@ -406,11 +397,10 @@ def test_raw_observation_to_observation_deterministic():
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
# Run twice with same input
|
||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Results should be identical
|
||||
assert set(obs1.keys()) == set(obs2.keys())
|
||||
@@ -448,7 +438,7 @@ def test_image_processing_pipeline_preserves_content():
|
||||
)
|
||||
}
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
|
||||
|
||||
|
||||
@@ -196,6 +196,9 @@ def test_predict_action_chunk(monkeypatch, policy_server):
|
||||
|
||||
# Force server to act-style policy; patch method to return deterministic tensor
|
||||
policy_server.policy_type = "act"
|
||||
# NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
|
||||
policy_server.preprocessor = lambda obs: obs
|
||||
policy_server.postprocessor = lambda tensor: tensor
|
||||
action_dim = 6
|
||||
batch_size = 1
|
||||
actions_per_chunk = policy_server.actions_per_chunk
|
||||
|
||||
Reference in New Issue
Block a user