Files
openpi/examples/inference.ipynb
lzy 65d864861b
Some checks are pending
pre-commit / pre-commit (push) Waiting to run
add
2025-04-26 22:10:42 +08:00

138 lines
4.0 KiB
Plaintext
Executable File

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import dataclasses\n",
"\n",
"import jax\n",
"\n",
"from openpi.models import model as _model\n",
"from openpi.policies import droid_policy\n",
"from openpi.policies import policy_config as _policy_config\n",
"from openpi.shared import download\n",
"from openpi.training import config as _config\n",
"from openpi.training import data_loader as _data_loader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Policy inference\n",
"\n",
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_fast_droid\")\n",
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
"\n",
"# Create a trained policy.\n",
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
"\n",
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
"example = droid_policy.make_droid_example()\n",
"result = policy.infer(example)\n",
"\n",
"# Delete the policy to free up memory.\n",
"del policy\n",
"\n",
"print(\"Actions shape:\", result[\"actions\"].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Working with a live model\n",
"\n",
"\n",
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_aloha_sim\")\n",
"\n",
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
"key = jax.random.key(0)\n",
"\n",
"# Create a model from the checkpoint.\n",
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
"\n",
"# We can create fake observations and actions to test the model.\n",
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
"\n",
"# Compute the training loss on the fake observations and actions.\n",
"loss = model.compute_loss(key, obs, act)\n",
"print(\"Loss shape:\", loss.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reduce the batch size to reduce memory usage.\n",
"config = dataclasses.replace(config, batch_size=2)\n",
"\n",
"# Load a single batch of data. This is the same data that will be used during training.\n",
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
"obs, act = next(iter(loader))\n",
"\n",
"# Compute the training loss on the real batch.\n",
"loss = model.compute_loss(key, obs, act)\n",
"\n",
"# Delete the model to free up memory.\n",
"del model\n",
"\n",
"print(\"Loss shape:\", loss.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}