feat(scripts): Introduce build_inference_frame/make_robot_action utils to easily allow API-based inference (#2143)
* fix: expose a function that explicitly builds a frame for inference
* fix: first make the dataset frame, then make it ready for inference
* fix: reduce reliance on lerobot record for the policy's outputs too
* fix: encapsulate squeezing out the batch dimension + device handling from predict_action
* fix: remove the duplicated call to build_inference_frame and add a function that only performs data type handling (the whole conversion is: key matching + data type conversion)
* fix(policies): correct utils signature + docstrings (#2198)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
commit 723013c71b (parent bf6ac5e110)
@@ -16,10 +16,16 @@

import logging
from collections import deque
from typing import Any

import numpy as np
import torch
from torch import nn

from lerobot.datasets.utils import build_dataset_frame
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR


def populate_queues(
    queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
@@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str])
        logging.warning(f"Missing key(s) when loading model: {missing_keys}")
    if unexpected_keys:
        logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")


# TODO(Steven): Move this function to a proper preprocessor step
def prepare_observation_for_inference(
    observation: dict[str, np.ndarray],
    device: torch.device,
    task: str | None = None,
    robot_type: str | None = None,
) -> RobotObservation:
    """Converts observation data to model-ready PyTorch tensors.

    This function takes a dictionary of NumPy arrays, performs necessary
    preprocessing, and prepares it for model inference. The steps include:
    1. Converting NumPy arrays to PyTorch tensors.
    2. Normalizing and permuting image data (if any).
    3. Adding a batch dimension to each tensor.
    4. Moving all tensors to the specified compute device.
    5. Adding task and robot type information to the dictionary.

    Args:
        observation: A dictionary mapping observation names (str) to NumPy
            array data. For images, the format is expected to be (H, W, C).
        device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the
            tensors will be moved.
        task: An optional string identifier for the current task.
        robot_type: An optional string identifier for the robot being used.

    Returns:
        A dictionary where values are PyTorch tensors preprocessed for
        inference, residing on the target device. Image tensors are reshaped
        to (C, H, W) and normalized to a [0, 1] range.
    """
    for name in observation:
        observation[name] = torch.from_numpy(observation[name])
        if "image" in name:
            observation[name] = observation[name].type(torch.float32) / 255
            observation[name] = observation[name].permute(2, 0, 1).contiguous()
        observation[name] = observation[name].unsqueeze(0)
        observation[name] = observation[name].to(device)

    observation["task"] = task if task else ""
    observation["robot_type"] = robot_type if robot_type else ""

    return observation
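For reference, a minimal usage sketch of `prepare_observation_for_inference` (not part of the diff): the observation keys, shapes, task string, and robot type below are illustrative placeholders.

```python
import numpy as np
import torch

from lerobot.policies.utils import prepare_observation_for_inference

# Hypothetical raw observation: an (H, W, C) uint8 camera frame plus a float32 state vector.
raw_obs = {
    "observation.images.front": np.zeros((480, 640, 3), dtype=np.uint8),
    "observation.state": np.zeros(6, dtype=np.float32),
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
obs = prepare_observation_for_inference(raw_obs, device, task="pick up the cube", robot_type="so100")

# The image is now float32 in [0, 1] with shape (1, C, H, W); the state has shape (1, 6).
print(obs["observation.images.front"].shape, obs["observation.state"].shape)
```

Note that the function mutates and returns the same dictionary, so callers should not expect the input dict to keep its NumPy values afterwards.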
def build_inference_frame(
    observation: dict[str, Any],
    device: torch.device,
    ds_features: dict[str, dict],
    task: str | None = None,
    robot_type: str | None = None,
) -> RobotObservation:
    """Constructs a model-ready observation tensor dict from a raw observation.

    This utility function orchestrates the process of converting a raw,
    unstructured observation from an environment into a structured,
    tensor-based format suitable for passing to a policy model.

    Args:
        observation: The raw observation dictionary, which may contain
            superfluous keys.
        device: The target PyTorch device for the final tensors.
        ds_features: A configuration dictionary that specifies which features
            to extract from the raw observation.
        task: An optional string identifier for the current task.
        robot_type: An optional string identifier for the robot being used.

    Returns:
        A dictionary of preprocessed tensors ready for model inference.
    """
    # Extracts the correct keys from the incoming raw observation
    observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)

    # Performs the necessary conversions to the observation
    observation = prepare_observation_for_inference(observation, device, task, robot_type)

    return observation
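A hedged sketch of `build_inference_frame` with the feature spec taken from a dataset's metadata; the repo id and raw observation keys are placeholders and must in practice match the feature names declared in `ds_features`.

```python
import numpy as np
import torch

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.utils import build_inference_frame

# Placeholder repo id; dataset.features describes the observation/action features the policy expects.
dataset = LeRobotDataset("my_org/my_dataset")

# Raw observation as it might arrive from a robot driver or an HTTP endpoint.
# Keys not described by ds_features (e.g. debug fields) are simply dropped.
raw_obs = {
    "front": np.zeros((480, 640, 3), dtype=np.uint8),  # camera frame, (H, W, C)
    "shoulder_pan.pos": 0.0,                            # one motor reading
    "latency_ms": 12.3,                                 # superfluous key
}

frame = build_inference_frame(
    observation=raw_obs,
    device=torch.device("cpu"),
    ds_features=dataset.features,
    task="pick up the cube",
    robot_type="so100",
)
```

The returned `frame` is the same tensor dictionary that `prepare_observation_for_inference` produces, so it can be handed to a policy preprocessor directly.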
def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction:
    """Converts a policy's output tensor into a dictionary of named actions.

    This function translates the numerical output from a policy model into a
    human-readable and robot-consumable format, where each dimension of the
    action tensor is mapped to a named motor or actuator command.

    Args:
        action_tensor: A PyTorch tensor representing the policy's action,
            typically with a batch dimension (e.g., shape [1, action_dim]).
        ds_features: A configuration dictionary containing metadata, including
            the names corresponding to each index of the action tensor.

    Returns:
        A dictionary mapping action names (e.g., "joint_1_motor") to their
        corresponding floating-point values, ready to be sent to a robot
        controller.
    """
    # TODO(Steven): Check if these steps are already in all postprocessor policies
    action_tensor = action_tensor.squeeze(0)
    action_tensor = action_tensor.to("cpu")

    action_names = ds_features[ACTION]["names"]
    act_processed_policy: RobotAction = {
        f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
    }
    return act_processed_policy
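A minimal sketch of `make_robot_action` (illustrative only): the feature dict is hand-built here instead of coming from a real dataset, and the motor names are placeholders.

```python
import torch

from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import ACTION

# Hand-built stand-in for dataset.features, listing one name per action dimension.
ds_features = {ACTION: {"names": ["shoulder_pan.pos", "shoulder_lift.pos", "gripper.pos"]}}

# Batched policy output of shape (1, action_dim), e.g. as returned by the postprocessor.
action = torch.tensor([[0.1, -0.25, 0.8]])

robot_action = make_robot_action(action, ds_features)
# -> {"shoulder_pan.pos": 0.1..., "shoulder_lift.pos": -0.25..., "gripper.pos": 0.8...}
```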
@@ -79,6 +79,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
    PolicyAction,
    PolicyProcessorPipeline,

@@ -316,10 +317,7 @@ def record_loop(
            robot_type=robot.robot_type,
        )

        action_names = dataset.features[ACTION]["names"]
        act_processed_policy: RobotAction = {
            f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
        }
        act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)

    elif policy is None and isinstance(teleop, Teleoperator):
        act = teleop.get_action()
@@ -31,6 +31,7 @@ from deepdiff import DeepDiff
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot

@@ -102,17 +103,7 @@ def predict_action(
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
            observation[name] = torch.from_numpy(observation[name])
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

        observation["task"] = task if task else ""
        observation["robot_type"] = robot_type if robot_type else ""

        observation = prepare_observation_for_inference(observation, device, task, robot_type)
        observation = preprocessor(observation)

        # Compute the next action with the policy
@@ -121,12 +112,6 @@ def predict_action(

        action = postprocessor(action)

        # Remove batch dimension
        action = action.squeeze(0)

        # Move to cpu, if not already the case
        action = action.to("cpu")

    return action
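Taken together, the new utilities make it possible to run a policy outside of `lerobot-record`, e.g. behind an API. The sketch below is an assumption-laden outline, not code from this PR: the policy class, checkpoint and dataset ids, and observation keys are placeholders, and the pre/post-processor pipelines that `predict_action` wraps around the policy (via `make_pre_post_processors`) are omitted for brevity.

```python
import numpy as np
import torch

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.act.modeling_act import ACTPolicy  # assumption: an ACT checkpoint
from lerobot.policies.utils import build_inference_frame, make_robot_action

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy = ACTPolicy.from_pretrained("my_org/my_act_checkpoint").to(device)
policy.eval()
dataset = LeRobotDataset("my_org/my_dataset")  # provides the feature spec used for both frames and actions

# Raw observation received from a robot client (names are illustrative).
raw_obs = {
    "front": np.zeros((480, 640, 3), dtype=np.uint8),
    "shoulder_pan.pos": 0.0,
}

# Match keys, convert dtypes, batch, and move to device.
frame = build_inference_frame(raw_obs, device, dataset.features, task="pick up the cube")

with torch.inference_mode():
    action = policy.select_action(frame)  # (1, action_dim) tensor

# Named {motor: float} dict, ready to send back to the robot controller.
robot_action = make_robot_action(action, dataset.features)
```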