fix(scripts): add missing observation overwrite in eval and async (#2265)

Author: Steven Palma
Date: 2025-10-20 23:34:24 +02:00
Committed by: GitHub
parent 5f6f476f32
commit b954337ac7
4 changed files with 12 additions and 3 deletions
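
The change adds a rename_map field to the async-inference and eval configs and forwards it to the rename_observations_processor override when the pre/post-processors are built, so a user-supplied key remapping actually reaches the preprocessor. A minimal sketch of what such a remapping does, assuming observations are plain dicts of tensors (the key names below are hypothetical examples, not lerobot's actual processor implementation):

# Illustrative only: what an observation rename map does.
import torch

def rename_observation(obs: dict[str, torch.Tensor], rename_map: dict[str, str]) -> dict[str, torch.Tensor]:
    # Remap keys listed in rename_map; keep all other keys unchanged.
    return {rename_map.get(key, key): value for key, value in obs.items()}

obs = {
    "observation.images.cam0": torch.zeros(3, 224, 224),
    "observation.state": torch.zeros(6),
}
# Map the robot's camera key onto the name the pretrained policy expects.
renamed = rename_observation(obs, {"observation.images.cam0": "observation.images.top"})
assert set(renamed) == {"observation.images.top", "observation.state"}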


@@ -16,7 +16,7 @@ import logging
 import logging.handlers
 import os
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from pathlib import Path

 import torch
@@ -268,6 +268,7 @@ class RemotePolicyConfig:
     lerobot_features: dict[str, PolicyFeature]
     actions_per_chunk: int
     device: str = "cpu"
+    rename_map: dict[str, str] = field(default_factory=dict)


 def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
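
The field import added above is needed because a dataclass field cannot take a mutable default like {} directly; field(default_factory=dict) gives every instance its own empty dict. A standalone sketch (the Example class is a stand-in, not repository code):

from dataclasses import dataclass, field

@dataclass
class Example:  # stand-in class, not from the repository
    rename_map: dict[str, str] = field(default_factory=dict)  # fresh dict per instance

a, b = Example(), Example()
a.rename_map["observation.images.cam0"] = "observation.images.top"
assert b.rename_map == {}  # no shared state between instances
# By contrast, writing `rename_map: dict[str, str] = {}` raises ValueError at class definition.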


@@ -159,7 +159,10 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         self.preprocessor, self.postprocessor = make_pre_post_processors(
             self.policy.config,
             pretrained_path=policy_specs.pretrained_name_or_path,
-            preprocessor_overrides={"device_processor": device_override},
+            preprocessor_overrides={
+                "device_processor": device_override,
+                "rename_observations_processor": {"rename_map": policy_specs.rename_map},
+            },
             postprocessor_overrides={"device_processor": device_override},
         )
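
The overrides dict is keyed by processor name, so adding a rename_observations_processor entry is what lets the configured map reach that step. A generic sketch of an "overrides keyed by step name" merge, assuming this is roughly how such overrides are applied (the builder function and defaults below are hypothetical, not lerobot's make_pre_post_processors):

from typing import Any

def apply_overrides(
    defaults: dict[str, dict[str, Any]],
    overrides: dict[str, dict[str, Any]] | None = None,
) -> dict[str, dict[str, Any]]:
    # Merge per-step keyword overrides into per-step defaults, keyed by step name.
    overrides = overrides or {}
    return {name: {**kwargs, **overrides.get(name, {})} for name, kwargs in defaults.items()}

steps = apply_overrides(
    defaults={
        "device_processor": {"device": "cpu"},
        "rename_observations_processor": {"rename_map": {}},
    },
    overrides={
        "device_processor": {"device": "cuda"},
        "rename_observations_processor": {"rename_map": {"observation.images.cam0": "observation.images.top"}},
    },
)
assert steps["rename_observations_processor"]["rename_map"] != {}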


@@ -36,6 +36,8 @@ class EvalPipelineConfig:
     output_dir: Path | None = None
     job_name: str | None = None
     seed: int | None = 1000
+    # Rename map for the observation to override the image and state keys
+    rename_map: dict[str, str] = field(default_factory=dict)

     def __post_init__(self) -> None:
         # HACK: We parse again the cli args here to get the pretrained path if there was one.


@@ -508,7 +508,10 @@ def eval_main(cfg: EvalPipelineConfig):
         policy_cfg=cfg.policy,
         pretrained_path=cfg.policy.pretrained_path,
         # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
-        preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
+        preprocessor_overrides={
+            "device_processor": {"device": str(policy.config.device)},
+            "rename_observations_processor": {"rename_map": cfg.rename_map},
+        },
     )
     with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
         info = eval_policy_all(
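
The surrounding context line pairs torch.no_grad() with a conditional autocast: mixed precision is entered only when AMP is requested, and nullcontext() makes the second context manager a no-op otherwise. A self-contained sketch of the same pattern (the model and the use_amp flag below are placeholders):

from contextlib import nullcontext
import torch

use_amp = torch.cuda.is_available()  # placeholder flag standing in for cfg.policy.use_amp
device = torch.device("cuda" if use_amp else "cpu")
model = torch.nn.Linear(4, 2).to(device).eval()
x = torch.randn(1, 4, device=device)

# Grad tracking is always disabled for inference; autocast only when AMP is requested.
with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
    y = model(x)
print(y.shape)  # torch.Size([1, 2])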