diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py
index f73cbc1da..2158f51ac 100644
--- a/src/lerobot/async_inference/helpers.py
+++ b/src/lerobot/async_inference/helpers.py
@@ -16,7 +16,7 @@ import logging
 import logging.handlers
 import os
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from pathlib import Path
 
 import torch
@@ -268,6 +268,7 @@ class RemotePolicyConfig:
     lerobot_features: dict[str, PolicyFeature]
     actions_per_chunk: int
     device: str = "cpu"
+    rename_map: dict[str, str] = field(default_factory=dict)
 
 
 def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py
index f7e00dea4..ab2e6bcd8 100644
--- a/src/lerobot/async_inference/policy_server.py
+++ b/src/lerobot/async_inference/policy_server.py
@@ -159,7 +159,10 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         self.preprocessor, self.postprocessor = make_pre_post_processors(
             self.policy.config,
             pretrained_path=policy_specs.pretrained_name_or_path,
-            preprocessor_overrides={"device_processor": device_override},
+            preprocessor_overrides={
+                "device_processor": device_override,
+                "rename_observations_processor": {"rename_map": policy_specs.rename_map},
+            },
             postprocessor_overrides={"device_processor": device_override},
         )
 
diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py
index e9e05a7e8..2f085da56 100644
--- a/src/lerobot/configs/eval.py
+++ b/src/lerobot/configs/eval.py
@@ -36,6 +36,8 @@ class EvalPipelineConfig:
     output_dir: Path | None = None
     job_name: str | None = None
     seed: int | None = 1000
+    # Rename map for the observation to override the image and state keys
+    rename_map: dict[str, str] = field(default_factory=dict)
 
     def __post_init__(self) -> None:
         # HACK: We parse again the cli args here to get the pretrained path if there was one.
diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py
index 0dec18be6..8796e897e 100644
--- a/src/lerobot/scripts/lerobot_eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -508,7 +508,10 @@ def eval_main(cfg: EvalPipelineConfig):
         policy_cfg=cfg.policy,
         pretrained_path=cfg.policy.pretrained_path,
         # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
-        preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
+        preprocessor_overrides={
+            "device_processor": {"device": str(policy.config.device)},
+            "rename_observations_processor": {"rename_map": cfg.rename_map},
+        },
     )
 
     with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
         info = eval_policy_all(
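
For context, below is a minimal sketch of the key-renaming behavior that the new rename_map option feeds into, assuming the processor simply maps incoming observation keys to the names the pretrained policy expects. The helper name and the observation keys are hypothetical illustrations, not the actual lerobot rename_observations_processor.

    # Hypothetical sketch, not the lerobot implementation: keys found in
    # rename_map are renamed; all other keys pass through unchanged.
    def apply_rename_map(observation: dict[str, object], rename_map: dict[str, str]) -> dict[str, object]:
        return {rename_map.get(key, key): value for key, value in observation.items()}

    # Illustrative keys only; real names depend on the robot and the policy checkpoint.
    obs = {"observation.images.cam_top": "frame", "observation.state": [0.0, 1.0]}
    print(apply_rename_map(obs, {"observation.images.cam_top": "observation.images.top"}))
    # -> {'observation.images.top': 'frame', 'observation.state': [0.0, 1.0]}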