feat(scripts): Introduce build_inference_frame/make_robot_action util to easily allow API-based Inference (#2143)

* fix: expose a function explicitly building a frame for inference

* fix: first make dataset frame, then make ready for inference

* fix: reducing reliance on lerobot record for policy's outputs too

* fix: encapsulating squeezing out + device handling from predict action

* fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion)

* fix(policies): right utils signature + docstrings (#2198)

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Francesco Capuano
2025-10-14 15:47:32 +02:00
committed by GitHub
parent bf6ac5e110
commit 723013c71b
3 changed files with 117 additions and 21 deletions

View File

@@ -16,10 +16,16 @@
import logging import logging
from collections import deque from collections import deque
from typing import Any
import numpy as np
import torch import torch
from torch import nn from torch import nn
from lerobot.datasets.utils import build_dataset_frame
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
def populate_queues( def populate_queues(
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
@@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str])
logging.warning(f"Missing key(s) when loading model: {missing_keys}") logging.warning(f"Missing key(s) when loading model: {missing_keys}")
if unexpected_keys: if unexpected_keys:
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}") logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
# TODO(Steven): Move this function to a proper preprocessor step
def prepare_observation_for_inference(
    observation: dict[str, np.ndarray],
    device: torch.device,
    task: str | None = None,
    robot_type: str | None = None,
) -> RobotObservation:
    """Convert a dict of NumPy observations into model-ready PyTorch tensors.

    The input dictionary is updated in place and also returned. For every
    entry the array is turned into a tensor, image entries (any key
    containing "image", expected in (H, W, C) layout) are additionally
    scaled to [0, 1] float32 and permuted to (C, H, W), and every tensor
    receives a leading batch dimension before being moved to ``device``.
    Finally the ``task`` and ``robot_type`` string fields are attached.

    Args:
        observation: Mapping from observation name to NumPy array data;
            image arrays must be channel-last (H, W, C).
        device: Target device (e.g. ``torch.device("cpu")``) for all tensors.
        task: Optional task identifier; stored as ``""`` when falsy.
        robot_type: Optional robot identifier; stored as ``""`` when falsy.

    Returns:
        The same dictionary, now holding batched tensors on ``device``,
        with image tensors shaped (1, C, H, W) and normalized to [0, 1].
    """
    for key in list(observation):
        tensor = torch.from_numpy(observation[key])
        if "image" in key:
            # (H, W, C) uint8 -> (C, H, W) float32 in [0, 1]
            tensor = tensor.type(torch.float32) / 255
            tensor = tensor.permute(2, 0, 1).contiguous()
        # Add the batch dimension, then move to the compute device.
        observation[key] = tensor.unsqueeze(0).to(device)
    observation["task"] = task or ""
    observation["robot_type"] = robot_type or ""
    return observation
def build_inference_frame(
    observation: dict[str, Any],
    device: torch.device,
    ds_features: dict[str, dict],
    task: str | None = None,
    robot_type: str | None = None,
) -> RobotObservation:
    """Turn a raw environment observation into a model-ready tensor dict.

    Two steps: first the raw observation is filtered/renamed down to the
    dataset-declared observation features, then the resulting frame is
    converted to batched tensors on the target device.

    Args:
        observation: Raw observation dictionary; extra keys are dropped.
        device: Target device for the resulting tensors.
        ds_features: Dataset feature spec selecting which keys to keep.
        task: Optional task identifier forwarded to the frame.
        robot_type: Optional robot identifier forwarded to the frame.

    Returns:
        A dictionary of preprocessed tensors ready for model inference.
    """
    # Keep only the dataset-declared observation keys.
    frame = build_dataset_frame(ds_features, observation, prefix=OBS_STR)
    # Tensor conversion, batching, and device placement.
    return prepare_observation_for_inference(frame, device, task, robot_type)
def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction:
    """Convert a policy's output tensor into a dictionary of named actions.

    Translates the numerical output from a policy model into a
    robot-consumable format, where each dimension of the action tensor is
    mapped to a named motor or actuator command.

    Args:
        action_tensor: Policy action tensor, typically with a batch
            dimension (e.g. shape [1, action_dim]).
        ds_features: Dataset feature spec; ``ds_features[ACTION]["names"]``
            lists the name corresponding to each index of the tensor.

    Returns:
        A dictionary mapping action names (e.g. "joint_1_motor") to their
        floating-point values, ready to be sent to a robot controller.

    Raises:
        IndexError: If there are more action names than tensor elements.
    """
    # TODO(Steven): Check if these steps are already in all postprocessor policies
    action_tensor = action_tensor.squeeze(0)  # drop the batch dimension
    action_tensor = action_tensor.to("cpu")
    action_names = ds_features[ACTION]["names"]
    # Single device->host transfer via .tolist() instead of one sync per element.
    values = action_tensor.tolist()
    return {name: float(values[i]) for i, name in enumerate(action_names)}

View File

@@ -79,6 +79,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import ( from lerobot.processor import (
PolicyAction, PolicyAction,
PolicyProcessorPipeline, PolicyProcessorPipeline,
@@ -316,10 +317,7 @@ def record_loop(
robot_type=robot.robot_type, robot_type=robot.robot_type,
) )
action_names = dataset.features[ACTION]["names"] act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
act_processed_policy: RobotAction = {
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
}
elif policy is None and isinstance(teleop, Teleoperator): elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action() act = teleop.get_action()

View File

@@ -31,6 +31,7 @@ from deepdiff import DeepDiff
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot from lerobot.robots import Robot
@@ -102,17 +103,7 @@ def predict_action(
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
): ):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation: observation = prepare_observation_for_inference(observation, device, task, robot_type)
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""
observation = preprocessor(observation) observation = preprocessor(observation)
# Compute the next action with the policy # Compute the next action with the policy
@@ -121,12 +112,6 @@ def predict_action(
action = postprocessor(action) action = postprocessor(action)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action return action