Streamline processor loading logic

This commit is contained in:
Ury Zhilinsky
2024-12-25 15:19:42 -08:00
committed by uzhilinsky
parent fc7b7bc694
commit 06c632b144
4 changed files with 41 additions and 6 deletions

View File

@@ -104,11 +104,13 @@ The training config is used to determine which data transformations should be ap
There are also a number of checkpoints that are available as exported JAX graphs, which we trained ourselves using our internal training code. These can be served using the following command:
```bash
uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_aloha/model --policy.processor=trossen_biarm_single_base_cam_24dim
uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_aloha/model [--policy.processor=trossen_biarm_single_base_cam_24dim]
```
In this case, the data transformations are taken from the default policy and the processor name will be used to determine which norms stats should be used to normalize the transformed data.
For these exported models, norm stats are loaded from processors that are exported along with the model, while data transformations are defined in the corresponding default policy (see `create_default_policy` in [scripts/serve_policy.py](scripts/serve_policy.py)). The processor name is optional, and if not provided, we will do the following:
- Try using the default environment processor name
- Load a processor if there is only one available
- Raise an error if there are multiple processors available and ask to provide a processor name
### Running with Docker:

View File

@@ -36,7 +36,8 @@ class Exported:
# Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
dir: str
# Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
# Processor name to load the norm stats from. If not provided, first try using the default environment processor.
# If not available, load a processor if there is only one available. If there are multiple processors, raise an error.
processor: str | None = None
@@ -120,15 +121,32 @@ def create_default_policy(
default_exported = DEFAULT_EXPORTED[env]
if exported:
checkpoint_dir = exported.dir
processor = exported.processor or default_exported.processor
processor = exported.processor
else:
checkpoint_dir = default_exported.dir
processor = default_exported.processor
assert processor, "Default processor must be always set"
logging.info("Loading model...")
model = _exported.PiModel.from_checkpoint(checkpoint_dir)
processors = model.processor_names()
if not processors:
raise ValueError(f"No processors found in {checkpoint_dir}")
if processor is None:
# First try using the default environment processor.
if default_exported.processor in processors:
processor = default_exported.processor
# If the default processor is not available, load a processor if there is only one available.
elif len(processors) == 1:
processor = processors[0]
# If there are multiple processors, ask the user to provide a processor name.
else:
raise ValueError(f"Processor name must be provided. Available: {processors}")
logging.info("Using processor: %s", default_exported.processor)
elif processor not in processors:
raise ValueError(f"Processor {processor} not found in {checkpoint_dir}, found {processors}")
def make_policy_config(
input_layers: Sequence[transforms.DataTransformFn],
output_layers: Sequence[transforms.DataTransformFn],

View File

@@ -140,8 +140,14 @@ class PiModel(_model.BaseModel):
return _example_to_obs(_make_batch(example))
def norm_stats(self, processor_name: str) -> dict[str, _normalize.NormStats]:
"""Load the norm stats from the checkpoint."""
return _import_norm_stats(self.ckpt_dir, processor_name)
def processor_names(self) -> list[str]:
"""List of processor names available in the checkpoint."""
processor_dir = self.ckpt_dir / "processors"
return [x.name for x in processor_dir.iterdir() if x.is_dir()]
def set_module(self, module: common.BaseModule, param_path: str) -> _model.Model:
"""Creates a new model that uses the same parameters but a different module.
@@ -253,6 +259,7 @@ def _example_to_obs(example: dict) -> common.Observation:
def _import_norm_stats(ckpt_dir: pathlib.Path | str, processor_name: str) -> dict[str, _normalize.NormStats]:
ckpt_dir = pathlib.Path(ckpt_dir).resolve()
path = ckpt_dir / "processors" / processor_name
if not path.exists():
raise FileNotFoundError(f"Processor {processor_name} not found in {ckpt_dir}")

View File

@@ -33,6 +33,14 @@ def test_exported_as_pi0():
assert diff < 10.0
def test_processor_loading():
pi_model = exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")
assert pi_model.processor_names() == ["huggingface_aloha_sim_transfer_cube"]
norm_stats = pi_model.norm_stats("huggingface_aloha_sim_transfer_cube")
assert sorted(norm_stats) == ["actions", "state"]
def test_convert_to_openpi(tmp_path: pathlib.Path):
output_dir = tmp_path / "output"