281 lines
9.4 KiB
Python
281 lines
9.4 KiB
Python
from collections.abc import Sequence
|
|
import dataclasses
|
|
import enum
|
|
import logging
|
|
from typing import Any
|
|
|
|
import tyro
|
|
|
|
from openpi import transforms
|
|
from openpi.models import exported as _exported
|
|
from openpi.models import model as _model
|
|
from openpi.policies import aloha_policy
|
|
from openpi.policies import calvin_policy
|
|
from openpi.policies import droid_policy, ur_policy
|
|
from openpi.policies import libero_policy
|
|
from openpi.policies import policy as _policy
|
|
from openpi.policies import policy_config as _policy_config
|
|
from openpi.serving import websocket_policy_server
|
|
from openpi.shared import delta_actions
|
|
from openpi.training import config as _config
|
|
|
|
|
|
class EnvMode(enum.Enum):
    """Supported environments.

    Each member's value is the lowercase string identifier of the environment.
    """

    # ALOHA, real hardware and simulation.
    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    # Other supported environments.
    DROID = "droid"
    CALVIN = "calvin"
    LIBERO = "libero"
    UR = "ur"
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Exported:
|
|
"""Load an exported checkpoint."""
|
|
|
|
# Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
|
|
dir: str
|
|
# Processor name to load the norm stats from. If not provided, first try using the default environment processor.
|
|
# If not available, load a processor if there is only one available. If there are multiple processors, raise an error.
|
|
processor: str | None = None
|
|
|
|
|
|
@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""

    # Name of the training config used to build the policy (e.g., "pi0_aloha_sim").
    config: str
    # Directory holding the trained checkpoint (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
    dir: str
|
|
|
|
|
|
@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""

    # Environment to serve the policy for.
    env: EnvMode = EnvMode.ALOHA_SIM
    # How to load the policy. If not provided, the environment's default policy is used.
    policy: Checkpoint | Exported | None = None

    # If provided, overrides the default prompt for the policy.
    default_prompt: str | None = None

    # Port the policy server listens on.
    port: int = 8000
    # Record the policy's behavior for debugging.
    record: bool = False
|
|
|
|
|
|
def repack_from_env(env: EnvMode) -> transforms.Group:
    """Creates environment specific repack transforms.

    Both ALOHA variants (real and simulated) get the ACT repack layers on the input and
    output side. Every other environment gets an empty ``transforms.Group()``.
    """
    # TODO(ury): Move this to the runtime.
    uses_act_repack = env in (EnvMode.ALOHA, EnvMode.ALOHA_SIM)
    if not uses_act_repack:
        return transforms.Group()

    return transforms.Group(
        inputs=[aloha_policy.ActInputsRepack()],
        outputs=[aloha_policy.ActOutputsRepack()],
    )
|
|
|
|
|
|
# Default exported models, keyed by environment. create_default_policy uses the entry for
# the requested environment when no explicit exported checkpoint is given, and falls back
# to the entry's processor when an explicit checkpoint omits one.
DEFAULT_EXPORTED: dict[EnvMode, Exported] = {
    EnvMode.ALOHA: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha/model",
        processor="trossen_biarm_single_base_cam_24dim",
    ),
    EnvMode.ALOHA_SIM: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha_sim/model",
        processor="huggingface_aloha_sim_transfer_cube",
    ),
    EnvMode.DROID: Exported(
        dir="s3://openpi-assets/exported/pi0_droid/model",
        processor="openx_droid",
    ),
    EnvMode.CALVIN: Exported(
        dir="s3://openpi-assets/exported/pi0_calvin/model",
        processor="calvin",
    ),
    EnvMode.LIBERO: Exported(
        dir="s3://openpi-assets/exported/pi0_libero/model",
        processor="libero",
    ),
    EnvMode.UR: Exported(
        dir="s3://openpi-assets/exported/pi0_base/model",
        processor="ur5_single_24dim",
    ),
}
|
|
|
|
|
|
def create_default_policy(
|
|
env: EnvMode, *, default_prompt: str | None = None, exported: Exported | None = None
|
|
) -> _policy.Policy:
|
|
model: _model.BaseModel
|
|
config: _policy_config.PolicyConfig
|
|
|
|
default_exported = DEFAULT_EXPORTED[env]
|
|
if exported:
|
|
checkpoint_dir = exported.dir
|
|
processor = exported.processor
|
|
else:
|
|
checkpoint_dir = default_exported.dir
|
|
processor = default_exported.processor
|
|
|
|
logging.info("Loading model...")
|
|
model = _exported.PiModel.from_checkpoint(checkpoint_dir)
|
|
|
|
processors = model.processor_names()
|
|
if not processors:
|
|
raise ValueError(f"No processors found in {checkpoint_dir}")
|
|
|
|
if processor is None:
|
|
# First try using the default environment processor.
|
|
if default_exported.processor in processors:
|
|
processor = default_exported.processor
|
|
# If the default processor is not available, load a processor if there is only one available.
|
|
elif len(processors) == 1:
|
|
processor = processors[0]
|
|
# If there are multiple processors, ask the user to provide a processor name.
|
|
else:
|
|
raise ValueError(f"Processor name must be provided. Available: {processors}")
|
|
logging.info("Using processor: %s", default_exported.processor)
|
|
elif processor not in processors:
|
|
raise ValueError(f"Processor {processor} not found in {checkpoint_dir}, found {processors}")
|
|
|
|
def make_policy_config(
|
|
input_layers: Sequence[transforms.DataTransformFn],
|
|
output_layers: Sequence[transforms.DataTransformFn],
|
|
sample_kwargs: dict[str, Any] | None = None,
|
|
):
|
|
sample_kwargs = sample_kwargs or {"num_steps": 10}
|
|
return _policy_config.PolicyConfig(
|
|
model=model,
|
|
norm_stats=model.norm_stats(processor),
|
|
default_prompt=default_prompt,
|
|
input_layers=input_layers,
|
|
output_layers=output_layers,
|
|
sample_kwargs=sample_kwargs,
|
|
)
|
|
|
|
logging.info("Creating policy...")
|
|
match env:
|
|
case EnvMode.ALOHA:
|
|
delta_action_mask = delta_actions.make_bool_mask(6, -1, 6, -1)
|
|
config = make_policy_config(
|
|
input_layers=[
|
|
aloha_policy.ActInputsRepack(),
|
|
aloha_policy.AlohaInputs(
|
|
action_dim=model.action_dim,
|
|
delta_action_mask=delta_action_mask,
|
|
adapt_to_pi=True,
|
|
),
|
|
],
|
|
output_layers=[
|
|
aloha_policy.AlohaOutputs(
|
|
delta_action_mask=delta_action_mask,
|
|
adapt_to_pi=True,
|
|
),
|
|
aloha_policy.ActOutputsRepack(),
|
|
],
|
|
)
|
|
case EnvMode.ALOHA_SIM:
|
|
config = make_policy_config(
|
|
input_layers=[
|
|
aloha_policy.ActInputsRepack(),
|
|
aloha_policy.AlohaInputs(action_dim=model.action_dim),
|
|
],
|
|
output_layers=[
|
|
aloha_policy.AlohaOutputs(),
|
|
aloha_policy.ActOutputsRepack(),
|
|
],
|
|
)
|
|
case EnvMode.DROID:
|
|
config = make_policy_config(
|
|
input_layers=[
|
|
droid_policy.DroidInputs(action_dim=model.action_dim),
|
|
],
|
|
output_layers=[
|
|
droid_policy.DroidOutputs(),
|
|
transforms.SubsampleActions(stride=5),
|
|
],
|
|
)
|
|
case EnvMode.CALVIN:
|
|
config = make_policy_config(
|
|
input_layers=[
|
|
calvin_policy.CalvinInputs(action_dim=model.action_dim),
|
|
],
|
|
output_layers=[
|
|
calvin_policy.CalvinOutputs(),
|
|
],
|
|
)
|
|
case EnvMode.LIBERO:
|
|
config = make_policy_config(
|
|
input_layers=[
|
|
libero_policy.LiberoInputs(action_dim=model.action_dim),
|
|
],
|
|
output_layers=[
|
|
libero_policy.LiberoOutputs(),
|
|
],
|
|
)
|
|
case EnvMode.UR:
|
|
delta_action_mask = delta_actions.make_bool_mask(6, -1)
|
|
|
|
config = make_policy_config(
|
|
input_layers=[
|
|
ur_policy.URInputs(action_dim=model.action_dim),
|
|
transforms.ResizeImages(224,224),
|
|
],
|
|
output_layers=[
|
|
ur_policy.UROutputs(
|
|
delta_action_mask=delta_action_mask,
|
|
)
|
|
],
|
|
)
|
|
case _:
|
|
raise ValueError(f"Unknown environment mode: {env}")
|
|
return _policy_config.create_policy(config)
|
|
|
|
|
|
def create_policy(args: Args) -> _policy.Policy:
|
|
match args.policy:
|
|
case Checkpoint():
|
|
return _policy_config.create_trained_policy(
|
|
_config.get_config(args.policy.config),
|
|
args.policy.dir,
|
|
repack_transforms=repack_from_env(args.env),
|
|
default_prompt=args.default_prompt,
|
|
)
|
|
case Exported():
|
|
return create_default_policy(args.env, default_prompt=args.default_prompt, exported=args.policy)
|
|
case None:
|
|
return create_default_policy(args.env, default_prompt=args.default_prompt)
|
|
|
|
|
|
def main(args: Args) -> None:
    """Build the policy described by ``args`` and serve it over a websocket server.

    The server is created on host "0.0.0.0" at ``args.port`` and ``serve_forever`` is
    called on it. When ``args.record`` is set, the policy is wrapped in a
    ``PolicyRecorder`` first.
    """
    served_policy = create_policy(args)

    if args.record:
        # Wrap the policy so its behavior is recorded for debugging ("policy_records").
        served_policy = _policy.PolicyRecorder(served_policy, "policy_records")

    logging.info("Creating server...")
    server = websocket_policy_server.WebsocketPolicyServer(
        policy=served_policy,
        host="0.0.0.0",
        port=args.port,
    )

    logging.info("Serving...")
    server.serve_forever()
|
|
|
|
|
|
if __name__ == "__main__":
    # INFO-level root logging; force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(level=logging.INFO, force=True)
    cli_args = tyro.cli(Args)
    main(cli_args)
|