Streamline processor loading logic
Commit 06c632b144 (parent fc7b7bc694), committed by uzhilinsky
````diff
@@ -104,11 +104,13 @@ The training config is used to determine which data transformations should be applied
 There are also a number of checkpoints that are available as exported JAX graphs, which we trained ourselves using our internal training code. These can be served using the following command:
 
 ```bash
-uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_aloha/model --policy.processor=trossen_biarm_single_base_cam_24dim
+uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_aloha/model [--policy.processor=trossen_biarm_single_base_cam_24dim]
 ```
 
-In this case, the data transformations are taken from the default policy and the processor name will be used to determine which norm stats should be used to normalize the transformed data.
+For these exported models, norm stats are loaded from processors that are exported along with the model, while data transformations are defined in the corresponding default policy (see `create_default_policy` in [scripts/serve_policy.py](scripts/serve_policy.py)). The processor name is optional; if it is not provided, we do the following:
+
+- Try the default processor name for the environment.
+- If that is unavailable, load the processor when there is only one available.
+- If multiple processors are available, raise an error and ask the user to provide a processor name.
 
 ### Running with Docker:
````
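To make the fallback concrete, here is a minimal, self-contained sketch of the selection rule described above. The helper name `resolve_processor` and the example processor names are ours, not part of openpi; the logic mirrors the README's three steps:

```python
def resolve_processor(requested: str | None, default: str | None, available: list[str]) -> str:
    """Pick a processor name using the fallback described in the README."""
    if not available:
        raise ValueError("No processors found in the checkpoint")
    if requested is not None:
        if requested not in available:
            raise ValueError(f"Processor {requested} not found, found {available}")
        return requested
    # 1. Try the default processor name for the environment.
    if default in available:
        return default
    # 2. Otherwise, load the processor if there is only one available.
    if len(available) == 1:
        return available[0]
    # 3. Otherwise the choice is ambiguous: ask for an explicit name.
    raise ValueError(f"Processor name must be provided. Available: {available}")


# One call per fallback branch (processor names are made up):
assert resolve_processor(None, "biarm", ["biarm", "sim_cube"]) == "biarm"
assert resolve_processor(None, "biarm", ["sim_cube"]) == "sim_cube"
try:
    resolve_processor(None, "biarm", ["a", "b"])
except ValueError as e:
    print(e)  # Processor name must be provided. Available: ['a', 'b']
```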
```diff
@@ -36,7 +36,8 @@ class Exported:
 
     # Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
     dir: str
 
-    # Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
+    # Processor name to load the norm stats from. If not provided, first try the default environment processor.
+    # If that is not available, load a processor if there is only one available; if there are multiple, raise an error.
     processor: str | None = None
```
```diff
@@ -120,15 +121,32 @@ def create_default_policy(
     default_exported = DEFAULT_EXPORTED[env]
     if exported:
         checkpoint_dir = exported.dir
-        processor = exported.processor or default_exported.processor
+        processor = exported.processor
     else:
         checkpoint_dir = default_exported.dir
         processor = default_exported.processor
         assert processor, "Default processor must always be set"
 
     logging.info("Loading model...")
     model = _exported.PiModel.from_checkpoint(checkpoint_dir)
 
+    processors = model.processor_names()
+    if not processors:
+        raise ValueError(f"No processors found in {checkpoint_dir}")
+    if processor is None:
+        # First try using the default environment processor.
+        if default_exported.processor in processors:
+            processor = default_exported.processor
+        # If the default processor is not available, load a processor if there is only one available.
+        elif len(processors) == 1:
+            processor = processors[0]
+        # If there are multiple processors, ask the user to provide a processor name.
+        else:
+            raise ValueError(f"Processor name must be provided. Available: {processors}")
+        logging.info("Using processor: %s", processor)
+    elif processor not in processors:
+        raise ValueError(f"Processor {processor} not found in {checkpoint_dir}, found {processors}")
+
     def make_policy_config(
         input_layers: Sequence[transforms.DataTransformFn],
         output_layers: Sequence[transforms.DataTransformFn],
```
```diff
@@ -140,8 +140,14 @@ class PiModel(_model.BaseModel):
         return _example_to_obs(_make_batch(example))
 
     def norm_stats(self, processor_name: str) -> dict[str, _normalize.NormStats]:
         """Load the norm stats from the checkpoint."""
         return _import_norm_stats(self.ckpt_dir, processor_name)
 
+    def processor_names(self) -> list[str]:
+        """List of processor names available in the checkpoint."""
+        processor_dir = self.ckpt_dir / "processors"
+        return [x.name for x in processor_dir.iterdir() if x.is_dir()]
+
     def set_module(self, module: common.BaseModule, param_path: str) -> _model.Model:
         """Creates a new model that uses the same parameters but a different module.
```
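The new `processor_names()` is driven purely by the on-disk layout of the exported checkpoint. A quick standalone illustration of the convention it assumes, using a scratch directory (a sketch only; real checkpoints also contain model parameters):

```python
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = pathlib.Path(tmp)
    # An exported checkpoint stores one subdirectory per processor under "processors/".
    for name in ["trossen_biarm_single_base_cam_24dim", "huggingface_aloha_sim_transfer_cube"]:
        (ckpt_dir / "processors" / name).mkdir(parents=True)

    # The same enumeration PiModel.processor_names() performs:
    processor_dir = ckpt_dir / "processors"
    print(sorted(x.name for x in processor_dir.iterdir() if x.is_dir()))
    # ['huggingface_aloha_sim_transfer_cube', 'trossen_biarm_single_base_cam_24dim']
```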
```diff
@@ -253,6 +259,7 @@ def _example_to_obs(example: dict) -> common.Observation:
 
 def _import_norm_stats(ckpt_dir: pathlib.Path | str, processor_name: str) -> dict[str, _normalize.NormStats]:
+    ckpt_dir = pathlib.Path(ckpt_dir).resolve()
 
     path = ckpt_dir / "processors" / processor_name
     if not path.exists():
         raise FileNotFoundError(f"Processor {processor_name} not found in {ckpt_dir}")
```
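The added first line coerces the `ckpt_dir` argument, so callers may pass either a `str` or a `pathlib.Path`. A minimal sketch of the resulting lookup behavior for a local checkpoint; the helper name `find_processor_dir` is ours, and the real function goes on to parse the norm stats files:

```python
import pathlib

def find_processor_dir(ckpt_dir: pathlib.Path | str, processor_name: str) -> pathlib.Path:
    # Accept either a string or a Path, as _import_norm_stats now does.
    ckpt_dir = pathlib.Path(ckpt_dir).resolve()
    path = ckpt_dir / "processors" / processor_name
    if not path.exists():
        raise FileNotFoundError(f"Processor {processor_name} not found in {ckpt_dir}")
    return path
```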
```diff
@@ -33,6 +33,14 @@ def test_exported_as_pi0():
     assert diff < 10.0
 
 
+def test_processor_loading():
+    pi_model = exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")
+    assert pi_model.processor_names() == ["huggingface_aloha_sim_transfer_cube"]
+
+    norm_stats = pi_model.norm_stats("huggingface_aloha_sim_transfer_cube")
+    assert sorted(norm_stats) == ["actions", "state"]
+
+
 def test_convert_to_openpi(tmp_path: pathlib.Path):
     output_dir = tmp_path / "output"
 
```