138 lines
4.0 KiB
Plaintext
Executable File
138 lines
4.0 KiB
Plaintext
Executable File
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import dataclasses\n",
|
|
"\n",
|
|
"import jax\n",
|
|
"\n",
|
|
"from openpi.models import model as _model\n",
|
|
"from openpi.policies import droid_policy\n",
|
|
"from openpi.policies import policy_config as _policy_config\n",
|
|
"from openpi.shared import download\n",
|
|
"from openpi.training import config as _config\n",
|
|
"from openpi.training import data_loader as _data_loader"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Policy inference\n",
|
|
"\n",
|
|
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"config = _config.get_config(\"pi0_fast_droid\")\n",
|
|
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
|
|
"\n",
|
|
"# Create a trained policy.\n",
|
|
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
|
|
"\n",
|
|
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
|
|
"example = droid_policy.make_droid_example()\n",
|
|
"result = policy.infer(example)\n",
|
|
"\n",
|
|
"# Delete the policy to free up memory.\n",
|
|
"del policy\n",
|
|
"\n",
|
|
"print(\"Actions shape:\", result[\"actions\"].shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Working with a live model\n",
|
|
"\n",
|
|
"\n",
|
|
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"config = _config.get_config(\"pi0_aloha_sim\")\n",
|
|
"\n",
|
|
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
|
|
"key = jax.random.key(0)\n",
|
|
"\n",
|
|
"# Create a model from the checkpoint.\n",
|
|
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
|
|
"\n",
|
|
"# We can create fake observations and actions to test the model.\n",
|
|
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
|
|
"\n",
|
|
"# Compute the training loss on the fake observations and actions.\n",
|
|
"loss = model.compute_loss(key, obs, act)\n",
|
|
"print(\"Loss shape:\", loss.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Reduce the batch size to reduce memory usage.\n",
|
|
"config = dataclasses.replace(config, batch_size=2)\n",
|
|
"\n",
|
|
"# Load a single batch of data. This is the same data that will be used during training.\n",
|
|
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
|
|
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
|
|
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
|
|
"obs, act = next(iter(loader))\n",
|
|
"\n",
|
|
"# Compute the training loss on the real batch.\n",
|
|
"loss = model.compute_loss(key, obs, act)\n",
|
|
"\n",
|
|
"# Delete the model to free up memory.\n",
|
|
"del model\n",
|
|
"\n",
|
|
"print(\"Loss shape:\", loss.shape)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|