Streamline processor loading logic

This commit is contained in:
Ury Zhilinsky
2024-12-25 15:19:42 -08:00
committed by uzhilinsky
parent fc7b7bc694
commit 06c632b144
4 changed files with 41 additions and 6 deletions

View File

@@ -36,7 +36,8 @@ class Exported:
# Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
dir: str
# Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
# Processor name to load the norm stats from. If not provided, first try using the default environment processor.
# If not available, load a processor if there is only one available. If there are multiple processors, raise an error.
processor: str | None = None
@@ -120,15 +121,32 @@ def create_default_policy(
default_exported = DEFAULT_EXPORTED[env]
if exported:
checkpoint_dir = exported.dir
processor = exported.processor or default_exported.processor
processor = exported.processor
else:
checkpoint_dir = default_exported.dir
processor = default_exported.processor
assert processor, "Default processor must be always set"
logging.info("Loading model...")
model = _exported.PiModel.from_checkpoint(checkpoint_dir)
processors = model.processor_names()
if not processors:
raise ValueError(f"No processors found in {checkpoint_dir}")
if processor is None:
# First try using the default environment processor.
if default_exported.processor in processors:
processor = default_exported.processor
# If the default processor is not available, load a processor if there is only one available.
elif len(processors) == 1:
processor = processors[0]
# If there are multiple processors, ask the user to provide a processor name.
else:
raise ValueError(f"Processor name must be provided. Available: {processors}")
logging.info("Using processor: %s", default_exported.processor)
elif processor not in processors:
raise ValueError(f"Processor {processor} not found in {checkpoint_dir}, found {processors}")
def make_policy_config(
input_layers: Sequence[transforms.DataTransformFn],
output_layers: Sequence[transforms.DataTransformFn],