fix(scripts): add missing observation overwrite in eval and async (#2265)
This commit is contained in:
@@ -16,7 +16,7 @@ import logging
|
|||||||
import logging.handlers
|
import logging.handlers
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -268,6 +268,7 @@ class RemotePolicyConfig:
|
|||||||
lerobot_features: dict[str, PolicyFeature]
|
lerobot_features: dict[str, PolicyFeature]
|
||||||
actions_per_chunk: int
|
actions_per_chunk: int
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
|
||||||
|
|||||||
@@ -159,7 +159,10 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
|||||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||||
self.policy.config,
|
self.policy.config,
|
||||||
pretrained_path=policy_specs.pretrained_name_or_path,
|
pretrained_path=policy_specs.pretrained_name_or_path,
|
||||||
preprocessor_overrides={"device_processor": device_override},
|
preprocessor_overrides={
|
||||||
|
"device_processor": device_override,
|
||||||
|
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
|
||||||
|
},
|
||||||
postprocessor_overrides={"device_processor": device_override},
|
postprocessor_overrides={"device_processor": device_override},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ class EvalPipelineConfig:
|
|||||||
output_dir: Path | None = None
|
output_dir: Path | None = None
|
||||||
job_name: str | None = None
|
job_name: str | None = None
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
|
# Rename map for the observation to override the image and state keys
|
||||||
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||||
|
|||||||
@@ -508,7 +508,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
policy_cfg=cfg.policy,
|
policy_cfg=cfg.policy,
|
||||||
pretrained_path=cfg.policy.pretrained_path,
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
preprocessor_overrides={
|
||||||
|
"device_processor": {"device": str(policy.config.device)},
|
||||||
|
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy_all(
|
info = eval_policy_all(
|
||||||
|
|||||||
Reference in New Issue
Block a user