feat(scripts): Introduce build_inference_frame/make_robot_action util to easily allow API-based Inference (#2143)
* fix: expose a function explicitly building a frame for inference * fix: first make dataset frame, then make ready for inference * fix: reducing reliance on lerobot record for policy's ouptuts too * fix: encapsulating squeezing out + device handling from predict action * fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion) * fix(policies): right utils signature + docstrings (#2198) --------- Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
committed by
GitHub
parent
bf6ac5e110
commit
723013c71b
@@ -16,10 +16,16 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from lerobot.datasets.utils import build_dataset_frame
|
||||||
|
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
|
||||||
|
|
||||||
def populate_queues(
|
def populate_queues(
|
||||||
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
|
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
|
||||||
@@ -85,3 +91,110 @@ def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str])
|
|||||||
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
|
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
|
||||||
if unexpected_keys:
|
if unexpected_keys:
|
||||||
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
|
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Move this function to a proper preprocessor step
|
||||||
|
def prepare_observation_for_inference(
|
||||||
|
observation: dict[str, np.ndarray],
|
||||||
|
device: torch.device,
|
||||||
|
task: str | None = None,
|
||||||
|
robot_type: str | None = None,
|
||||||
|
) -> RobotObservation:
|
||||||
|
"""Converts observation data to model-ready PyTorch tensors.
|
||||||
|
|
||||||
|
This function takes a dictionary of NumPy arrays, performs necessary
|
||||||
|
preprocessing, and prepares it for model inference. The steps include:
|
||||||
|
1. Converting NumPy arrays to PyTorch tensors.
|
||||||
|
2. Normalizing and permuting image data (if any).
|
||||||
|
3. Adding a batch dimension to each tensor.
|
||||||
|
4. Moving all tensors to the specified compute device.
|
||||||
|
5. Adding task and robot type information to the dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation: A dictionary mapping observation names (str) to NumPy
|
||||||
|
array data. For images, the format is expected to be (H, W, C).
|
||||||
|
device: The PyTorch device (e.g., 'cpu' or 'cuda') to which the
|
||||||
|
tensors will be moved.
|
||||||
|
task: An optional string identifier for the current task.
|
||||||
|
robot_type: An optional string identifier for the robot being used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary where values are PyTorch tensors preprocessed for
|
||||||
|
inference, residing on the target device. Image tensors are reshaped
|
||||||
|
to (C, H, W) and normalized to a [0, 1] range.
|
||||||
|
"""
|
||||||
|
for name in observation:
|
||||||
|
observation[name] = torch.from_numpy(observation[name])
|
||||||
|
if "image" in name:
|
||||||
|
observation[name] = observation[name].type(torch.float32) / 255
|
||||||
|
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||||
|
observation[name] = observation[name].unsqueeze(0)
|
||||||
|
observation[name] = observation[name].to(device)
|
||||||
|
|
||||||
|
observation["task"] = task if task else ""
|
||||||
|
observation["robot_type"] = robot_type if robot_type else ""
|
||||||
|
|
||||||
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
def build_inference_frame(
|
||||||
|
observation: dict[str, Any],
|
||||||
|
device: torch.device,
|
||||||
|
ds_features: dict[str, dict],
|
||||||
|
task: str | None = None,
|
||||||
|
robot_type: str | None = None,
|
||||||
|
) -> RobotObservation:
|
||||||
|
"""Constructs a model-ready observation tensor dict from a raw observation.
|
||||||
|
|
||||||
|
This utility function orchestrates the process of converting a raw,
|
||||||
|
unstructured observation from an environment into a structured,
|
||||||
|
tensor-based format suitable for passing to a policy model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation: The raw observation dictionary, which may contain
|
||||||
|
superfluous keys.
|
||||||
|
device: The target PyTorch device for the final tensors.
|
||||||
|
ds_features: A configuration dictionary that specifies which features
|
||||||
|
to extract from the raw observation.
|
||||||
|
task: An optional string identifier for the current task.
|
||||||
|
robot_type: An optional string identifier for the robot being used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of preprocessed tensors ready for model inference.
|
||||||
|
"""
|
||||||
|
# Extracts the correct keys from the incoming raw observation
|
||||||
|
observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)
|
||||||
|
|
||||||
|
# Performs the necessary conversions to the observation
|
||||||
|
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||||
|
|
||||||
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict]) -> RobotAction:
|
||||||
|
"""Converts a policy's output tensor into a dictionary of named actions.
|
||||||
|
|
||||||
|
This function translates the numerical output from a policy model into a
|
||||||
|
human-readable and robot-consumable format, where each dimension of the
|
||||||
|
action tensor is mapped to a named motor or actuator command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_tensor: A PyTorch tensor representing the policy's action,
|
||||||
|
typically with a batch dimension (e.g., shape [1, action_dim]).
|
||||||
|
ds_features: A configuration dictionary containing metadata, including
|
||||||
|
the names corresponding to each index of the action tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping action names (e.g., "joint_1_motor") to their
|
||||||
|
corresponding floating-point values, ready to be sent to a robot
|
||||||
|
controller.
|
||||||
|
"""
|
||||||
|
# TODO(Steven): Check if these steps are already in all postprocessor policies
|
||||||
|
action_tensor = action_tensor.squeeze(0)
|
||||||
|
action_tensor = action_tensor.to("cpu")
|
||||||
|
|
||||||
|
action_names = ds_features[ACTION]["names"]
|
||||||
|
act_processed_policy: RobotAction = {
|
||||||
|
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
|
||||||
|
}
|
||||||
|
return act_processed_policy
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
|||||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.policies.utils import make_robot_action
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
PolicyAction,
|
PolicyAction,
|
||||||
PolicyProcessorPipeline,
|
PolicyProcessorPipeline,
|
||||||
@@ -316,10 +317,7 @@ def record_loop(
|
|||||||
robot_type=robot.robot_type,
|
robot_type=robot.robot_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
action_names = dataset.features[ACTION]["names"]
|
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||||
act_processed_policy: RobotAction = {
|
|
||||||
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
|
|
||||||
}
|
|
||||||
|
|
||||||
elif policy is None and isinstance(teleop, Teleoperator):
|
elif policy is None and isinstance(teleop, Teleoperator):
|
||||||
act = teleop.get_action()
|
act = teleop.get_action()
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from deepdiff import DeepDiff
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.policies.utils import prepare_observation_for_inference
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
from lerobot.robots import Robot
|
from lerobot.robots import Robot
|
||||||
|
|
||||||
@@ -102,17 +103,7 @@ def predict_action(
|
|||||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||||
):
|
):
|
||||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||||
for name in observation:
|
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||||
observation[name] = torch.from_numpy(observation[name])
|
|
||||||
if "image" in name:
|
|
||||||
observation[name] = observation[name].type(torch.float32) / 255
|
|
||||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
|
||||||
observation[name] = observation[name].unsqueeze(0)
|
|
||||||
observation[name] = observation[name].to(device)
|
|
||||||
|
|
||||||
observation["task"] = task if task else ""
|
|
||||||
observation["robot_type"] = robot_type if robot_type else ""
|
|
||||||
|
|
||||||
observation = preprocessor(observation)
|
observation = preprocessor(observation)
|
||||||
|
|
||||||
# Compute the next action with the policy
|
# Compute the next action with the policy
|
||||||
@@ -121,12 +112,6 @@ def predict_action(
|
|||||||
|
|
||||||
action = postprocessor(action)
|
action = postprocessor(action)
|
||||||
|
|
||||||
# Remove batch dimension
|
|
||||||
action = action.squeeze(0)
|
|
||||||
|
|
||||||
# Move to cpu, if not already the case
|
|
||||||
action = action.to("cpu")
|
|
||||||
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user