fix(scripts): add missing observation overwrite in eval and async (#2265)

This commit is contained in:
Steven Palma
2025-10-20 23:34:24 +02:00
committed by GitHub
parent 5f6f476f32
commit b954337ac7
4 changed files with 12 additions and 3 deletions

View File

@@ -16,7 +16,7 @@ import logging
import logging.handlers
import os
import time
from dataclasses import dataclass, field
from pathlib import Path

import torch
@@ -268,6 +268,7 @@ class RemotePolicyConfig:
    lerobot_features: dict[str, PolicyFeature]
    actions_per_chunk: int
    device: str = "cpu"
    rename_map: dict[str, str] = field(default_factory=dict)


def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:

View File

@@ -159,7 +159,10 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
        self.preprocessor, self.postprocessor = make_pre_post_processors(
            self.policy.config,
            pretrained_path=policy_specs.pretrained_name_or_path,
            preprocessor_overrides={
                "device_processor": device_override,
                "rename_observations_processor": {"rename_map": policy_specs.rename_map},
            },
            postprocessor_overrides={"device_processor": device_override},
        )

View File

@@ -36,6 +36,8 @@ class EvalPipelineConfig:
    output_dir: Path | None = None
    job_name: str | None = None
    seed: int | None = 1000
    # Rename map for the observation to override the image and state keys
    rename_map: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@@ -508,7 +508,10 @@ def eval_main(cfg: EvalPipelineConfig):
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
        # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
        preprocessor_overrides={
            "device_processor": {"device": str(policy.config.device)},
            "rename_observations_processor": {"rename_map": cfg.rename_map},
        },
    )

    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        info = eval_policy_all(