Streamline processor loading logic
This commit is contained in:
committed by
uzhilinsky
parent
fc7b7bc694
commit
06c632b144
@@ -36,7 +36,8 @@ class Exported:
|
||||
|
||||
# Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
|
||||
dir: str
|
||||
# Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
|
||||
# Processor name to load the norm stats from. If not provided, first try using the default environment processor.
|
||||
# If not available, load a processor if there is only one available. If there are multiple processors, raise an error.
|
||||
processor: str | None = None
|
||||
|
||||
|
||||
@@ -120,15 +121,32 @@ def create_default_policy(
|
||||
default_exported = DEFAULT_EXPORTED[env]
|
||||
if exported:
|
||||
checkpoint_dir = exported.dir
|
||||
processor = exported.processor or default_exported.processor
|
||||
processor = exported.processor
|
||||
else:
|
||||
checkpoint_dir = default_exported.dir
|
||||
processor = default_exported.processor
|
||||
assert processor, "Default processor must be always set"
|
||||
|
||||
logging.info("Loading model...")
|
||||
model = _exported.PiModel.from_checkpoint(checkpoint_dir)
|
||||
|
||||
processors = model.processor_names()
|
||||
if not processors:
|
||||
raise ValueError(f"No processors found in {checkpoint_dir}")
|
||||
|
||||
if processor is None:
|
||||
# First try using the default environment processor.
|
||||
if default_exported.processor in processors:
|
||||
processor = default_exported.processor
|
||||
# If the default processor is not available, load a processor if there is only one available.
|
||||
elif len(processors) == 1:
|
||||
processor = processors[0]
|
||||
# If there are multiple processors, ask the user to provide a processor name.
|
||||
else:
|
||||
raise ValueError(f"Processor name must be provided. Available: {processors}")
|
||||
logging.info("Using processor: %s", default_exported.processor)
|
||||
elif processor not in processors:
|
||||
raise ValueError(f"Processor {processor} not found in {checkpoint_dir}, found {processors}")
|
||||
|
||||
def make_policy_config(
|
||||
input_layers: Sequence[transforms.DataTransformFn],
|
||||
output_layers: Sequence[transforms.DataTransformFn],
|
||||
|
||||
Reference in New Issue
Block a user