diff --git a/docs/source/async.mdx b/docs/source/async.mdx
index c66cdb14..be10f8ba 100644
--- a/docs/source/async.mdx
+++ b/docs/source/async.mdx
@@ -31,15 +31,15 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif
 You can spin up a policy server running:
 ```shell
-python src/lerobot/async_inference/policy_server.py \
-    --host=127.0.0.1 \
-    --port=8080 \
+python -m lerobot.async_inference.policy_server \
+    --host=127.0.0.1 \
+    --port=8080
 ```

 This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client.

 Spin up a client with:
 ```shell
-python src/lerobot/async_inference/robot_client.py \
+python -m lerobot.async_inference.robot_client \
     --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
     --robot.type=so100_follower \ # ROBOT: your robot type
     --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
@@ -113,9 +113,9 @@ As such, spinning up a policy server is as easy as specifying the host address a

 ```bash
-python -m lerobot.scripts.server.policy_server \
-    --host="localhost" \
-    --port=8080
+python -m lerobot.async_inference.policy_server \
+    --host=127.0.0.1 \
+    --port=8080
 ```

@@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio

 ```bash
-python src/lerobot/async_inference/robot_client.py \
+python -m lerobot.async_inference.robot_client \
     --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
     --robot.type=so100_follower \ # ROBOT: your robot type
     --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py
index 5ebf3780..1b1dac0f 100644
--- a/src/lerobot/async_inference/constants.py
+++ b/src/lerobot/async_inference/constants.py
@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
 SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]

 # TODO: Add all other robots
-SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
+SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py
index 88fb00a3..54fad8c5 100644
--- a/src/lerobot/async_inference/helpers.py
+++ b/src/lerobot/async_inference/helpers.py
@@ -92,11 +92,11 @@ def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int,
     return resized.squeeze(0)


+# TODO(Steven): Consider implementing a pipeline step for this
 def raw_observation_to_observation(
     raw_observation: RawObservation,
     lerobot_features: dict[str, dict],
     policy_image_features: dict[str, PolicyFeature],
-    device: str,
 ) -> Observation:
     observation = {}

@@ -105,9 +105,7 @@ def raw_observation_to_observation(
         if isinstance(v, torch.Tensor):
             # VLAs present natural-language instructions in observations
             if "image" in k:
                 # Policy expects images in shape (B, C, H, W)
-                observation[k] = prepare_image(v).unsqueeze(0).to(device)
-            else:
-                observation[k] = v.to(device)
+                observation[k] = prepare_image(v).unsqueeze(0)
         else:
             observation[k] = v
diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py
index 12572706..f7e00dea 100644
--- a/src/lerobot/async_inference/policy_server.py
+++ b/src/lerobot/async_inference/policy_server.py
@@ -15,7 +15,7 @@
 """
 Example:
 ```shell
-python src/lerobot/async_inference/policy_server.py \
+python -m lerobot.async_inference.policy_server \
     --host=127.0.0.1 \
     --port=8080 \
     --fps=30 \
@@ -32,12 +32,17 @@ from concurrent import futures
 from dataclasses import asdict
 from pprint import pformat
 from queue import Empty, Queue
+from typing import Any

 import draccus
 import grpc
 import torch

-from lerobot.policies.factory import get_policy_class
+from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.processor import (
+    PolicyAction,
+    PolicyProcessorPipeline,
+)
 from lerobot.transport import (
     services_pb2,  # type: ignore
     services_pb2_grpc,  # type: ignore
@@ -82,6 +87,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         self.lerobot_features = None
         self.actions_per_chunk = None
         self.policy = None
+        self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
+        self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None

     @property
     def running(self):
@@ -146,6 +153,16 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         start = time.perf_counter()
         self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
         self.policy.to(self.device)
+
+        # Load preprocessor and postprocessor, overriding device to match requested device
+        device_override = {"device": self.device}
+        self.preprocessor, self.postprocessor = make_pre_post_processors(
+            self.policy.config,
+            pretrained_path=policy_specs.pretrained_name_or_path,
+            preprocessor_overrides={"device_processor": device_override},
+            postprocessor_overrides={"device_processor": device_override},
+        )
+
         end = time.perf_counter()
         self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")

@@ -173,7 +190,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         # Calculate FPS metrics
         fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)

-        self.logger.info(
+        self.logger.debug(
             f"Received observation #{obs_timestep} | "
             f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "  # fps at which observations are received from client
             f"Target: {fps_metrics['target_fps']:.2f} | "
@@ -189,7 +206,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):

         if not self._enqueue_observation(
             timed_observation  # wrapping a RawObservation
         ):
-            self.logger.info(f"Observation #{obs_timestep} has been filtered out")
+            self.logger.debug(f"Observation #{obs_timestep} has been filtered out")

         return services_pb2.Empty()
@@ -301,23 +318,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
             for i, action in enumerate(action_chunk)
         ]

-    def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
-        """
-        Prepare observation, ready for policy inference.
-        E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
-        client and then convert them to float32 [0,1] images here, before running inference.
- """ - # RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape - observation: Observation = raw_observation_to_observation( - observation_t.get_observation(), - self.lerobot_features, - self.policy_image_features, - self.device, - ) - # processed Observation - right keys, right dtype, right image shape - - return observation - def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: """Get an action chunk from the policy. The chunk contains only""" chunk = self.policy.predict_action_chunk(observation) @@ -327,44 +327,76 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): return chunk[:, : self.actions_per_chunk, :] def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]: - """Predict an action chunk based on an observation""" - inference_starts = time.perf_counter() + """Predict an action chunk based on an observation. + Pipeline: + 1. Convert raw observation to LeRobot format + 2. Apply preprocessor (tokenization, normalization, batching, device placement) + 3. Run policy inference to get action chunk + 4. Apply postprocessor (unnormalization, device movement) + 5. Convert to TimedAction list + """ """1. Prepare observation""" - start_time = time.perf_counter() - observation = self._prepare_observation(observation_t) - preprocessing_time = time.perf_counter() - start_time + start_prepare = time.perf_counter() + observation: Observation = raw_observation_to_observation( + observation_t.get_observation(), + self.lerobot_features, + self.policy_image_features, + ) + prepare_time = time.perf_counter() - start_prepare + """2. Apply preprocessor""" + start_preprocess = time.perf_counter() + observation = self.preprocessor(observation) self.last_processed_obs: TimedObservation = observation_t + preprocessing_time = time.perf_counter() - start_preprocess - """2. Get action chunk""" - start_time = time.perf_counter() + """3. Get action chunk""" + start_inference = time.perf_counter() action_tensor = self._get_action_chunk(observation) - inference_time = time.perf_counter() - start_time + inference_time = time.perf_counter() - start_inference + self.logger.info( + f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}" + ) - """3. Post-inference processing""" - start_time = time.perf_counter() - # Move to CPU before serializing - action_tensor = action_tensor.cpu().squeeze(0) + """4. Apply postprocessor""" + # Apply postprocessor (handles unnormalization and device movement) + # Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim) + # So we process each action in the chunk individually + start_postprocess = time.perf_counter() + _, chunk_size, _ = action_tensor.shape + # Process each action in the chunk + processed_actions = [] + for i in range(chunk_size): + # Extract action at timestep i: (B, action_dim) + single_action = action_tensor[:, i, :] + processed_action = self.postprocessor(single_action) + processed_actions.append(processed_action) + + # Stack back to (B, chunk_size, action_dim), then remove batch dim + action_tensor = torch.stack(processed_actions, dim=1).squeeze(0) + self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}") + + """5. 
+        """5. Convert to TimedAction list"""
         action_chunk = self._time_action_chunk(
             observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
         )
-        postprocessing_time = time.perf_counter() - start_time
-        inference_stops = time.perf_counter()
+        postprocess_stops = time.perf_counter()
+        postprocessing_time = postprocess_stops - start_postprocess

         self.logger.info(
-            f"Observation {observation_t.get_timestep()} |"
-            f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
+            f"Observation {observation_t.get_timestep()} | "
+            f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
         )

-        # full-process latency breakdown for debugging purposes
         self.logger.debug(
             f"Observation {observation_t.get_timestep()} | "
-            f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | "
-            f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | "
-            f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | "
-            f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms"
+            f"Prepare time: {1000 * prepare_time:.2f}ms | "
+            f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
+            f"Inference time: {1000 * inference_time:.2f}ms | "
+            f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
+            f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
         )

         return action_chunk
diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py
index c969bc60..8c4425c6 100644
--- a/src/lerobot/async_inference/robot_client.py
+++ b/src/lerobot/async_inference/robot_client.py
@@ -52,6 +52,7 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.robots import (  # noqa: F401
     Robot,
     RobotConfig,
+    bi_so100_follower,
     koch_follower,
     make_robot_from_config,
     so100_follower,
@@ -214,7 +215,7 @@ class RobotClient:
         )
         _ = self.stub.SendObservations(observation_iterator)
         obs_timestep = obs.get_timestep()
-        self.logger.info(f"Sent observation #{obs_timestep} | ")
+        self.logger.debug(f"Sent observation #{obs_timestep} | ")

         return True
@@ -467,7 +468,7 @@ class RobotClient:
             if self._ready_to_send_observation():
                 _captured_observation = self.control_loop_observation(task, verbose)

-            self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
+            self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")

             # Dynamically adjust sleep time to maintain the desired control frequency
             time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py
index 2689f061..ebaef2ef 100644
--- a/tests/async_inference/test_e2e.py
+++ b/tests/async_inference/test_e2e.py
@@ -91,6 +91,9 @@ def test_async_inference_e2e(monkeypatch):
     policy_server.policy = MockPolicy()
     policy_server.actions_per_chunk = 20
     policy_server.device = "cpu"
+    # NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
+    policy_server.preprocessor = lambda obs: obs
+    policy_server.postprocessor = lambda tensor: tensor

     # Set up robot config and features
     robot_config = MockRobotConfig()
diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py
index acf5870d..1e2d1e31 100644
--- a/tests/async_inference/test_helpers.py
+++ b/tests/async_inference/test_helpers.py
@@ -333,9 +333,8 @@ def test_raw_observation_to_observation_basic():
     robot_obs = _create_mock_robot_observation()
     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()
-    device = "cpu"

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     # Check that all expected keys are present
     assert OBS_STATE in observation
@@ -345,7 +344,6 @@ def test_raw_observation_to_observation_basic():
     # Check state processing
     state = observation[OBS_STATE]
     assert isinstance(state, torch.Tensor)
-    assert state.device.type == device
     assert state.shape == (1, 4)  # Batched

     # Check image processing
@@ -356,10 +354,6 @@ def test_raw_observation_to_observation_basic():
     assert laptop_img.shape == (1, 3, 224, 224)
     assert phone_img.shape == (1, 3, 160, 160)

-    # Check device placement
-    assert laptop_img.device.type == device
-    assert phone_img.device.type == device
-
     # Check image dtype and range (should be float32 in [0, 1])
     assert laptop_img.dtype == torch.float32
     assert phone_img.dtype == torch.float32
@@ -374,9 +368,8 @@ def test_raw_observation_to_observation_with_non_tensor_data():

     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()
-    device = "cpu"

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     # Check that task string is preserved
     assert "task" in observation
@@ -386,19 +379,17 @@ def test_raw_observation_to_observation_with_non_tensor_data():

 @torch.no_grad()
 def test_raw_observation_to_observation_device_handling():
-    """Test that tensors are properly moved to the specified device."""
-    device = "mps" if torch.backends.mps.is_available() else "cpu"
-
+    """Test that tensors are created (device placement is handled by preprocessor)."""
     robot_obs = _create_mock_robot_observation()
     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

-    # Check that all tensors are on the correct device
+    # Check that all expected keys produce tensors (device placement handled by preprocessor later)
     for key, value in observation.items():
         if isinstance(value, torch.Tensor):
-            assert value.device.type == device, f"Tensor {key} not on {device}"
+            assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"


@@ -406,11 +397,10 @@ def test_raw_observation_to_observation_deterministic():
     robot_obs = _create_mock_robot_observation()
     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()
-    device = "cpu"

     # Run twice with same input
-    obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
-    obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
+    obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     # Results should be identical
     assert set(obs1.keys()) == set(obs2.keys())
@@ -448,7 +438,7 @@ def test_image_processing_pipeline_preserves_content():
         )
     }

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0)  # Remove batch dim

diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py
index de441ff0..29583d4f 100644
--- a/tests/async_inference/test_policy_server.py
+++ b/tests/async_inference/test_policy_server.py
@@ -196,6 +196,9 @@ def test_predict_action_chunk(monkeypatch, policy_server):
     # Force server to act-style policy; patch method to return deterministic tensor
     policy_server.policy_type = "act"
+    # NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
+    policy_server.preprocessor = lambda obs: obs
+    policy_server.postprocessor = lambda tensor: tensor

     action_dim = 6
     batch_size = 1
     actions_per_chunk = policy_server.actions_per_chunk
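
Taken together, the server-side changes above replace the old `_prepare_observation` plus manual `.to(device)` path with a preprocessor/postprocessor pipeline. The sketch below is not part of the diff; it only strings together the calls the diff introduces, as an illustration. The function name `predict_chunk_sketch` and the standalone framing are hypothetical, the import path follows the file layout shown above, and the real logic lives in `PolicyServer._predict_action_chunk`.

```python
import torch

# Import path assumed from src/lerobot/async_inference/helpers.py in the diff above
from lerobot.async_inference.helpers import raw_observation_to_observation


def predict_chunk_sketch(
    policy, preprocessor, postprocessor,
    raw_observation, lerobot_features, policy_image_features, actions_per_chunk,
):
    # 1. Raw robot observation -> LeRobot observation (no device placement here anymore)
    observation = raw_observation_to_observation(
        raw_observation, lerobot_features, policy_image_features
    )

    # 2. Preprocessor handles tokenization, normalization, batching and device placement
    observation = preprocessor(observation)

    # 3. Policy returns a (B, chunk_size, action_dim) chunk; keep the first actions_per_chunk steps
    chunk = policy.predict_action_chunk(observation)[:, :actions_per_chunk, :]

    # 4. Postprocessor expects (B, action_dim), so unnormalize each timestep slice individually
    processed = [postprocessor(chunk[:, i, :]) for i in range(chunk.shape[1])]

    # 5. Stack back to (B, chunk_size, action_dim), then drop the batch dim before timestamping
    return torch.stack(processed, dim=1).squeeze(0)
```

The per-step loop mirrors the comment in the diff: the postprocessor operates on `(B, action_dim)` slices rather than whole chunks, which is also why the tests above can stub it with an identity `lambda`.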