{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Use the Hugging Face mirror unless the caller has already chosen an endpoint.\n",
    "# Kept ahead of the remaining imports, presumably so that libraries which read\n",
    "# HF_ENDPOINT at import time see the value - TODO confirm.\n",
    "os.environ.setdefault(\"HF_ENDPOINT\", \"https://hf-mirror.com\")\n",
    "\n",
    "import dataclasses\n",
    "\n",
    "import jax\n",
    "\n",
    "from openpi.models import model as _model\n",
    "from openpi.policies import droid_policy\n",
    "from openpi.policies import policy_config as _policy_config\n",
    "from openpi.shared import download\n",
    "from openpi.training import config as _config\n",
    "from openpi.training import data_loader as _data_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy inference\n",
    "\n",
    "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_fast_droid\")\n",
    "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
    "\n",
    "# Create a trained policy.\n",
    "# NOTE(review): the output previously saved with this cell showed the call below failing with\n",
    "# `ValueError: quantile stats must be provided if use_quantile_norm is True. Key actions is\n",
    "# missing q01 or q99.` Verify that the norm stats loaded for this checkpoint contain q01/q99.\n",
    "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
    "\n",
    "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
    "example = droid_policy.make_droid_example()\n",
    "result = policy.infer(example)\n",
    "\n",
    "# Delete the policy to free up memory.\n",
    "del policy\n",
    "\n",
    "print(\"Actions shape:\", result[\"actions\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Working with a live model\n",
    "\n",
    "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_aloha_sim\")\n",
    "\n",
    "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
    "key = jax.random.key(0)\n",
    "\n",
    "# Create a model from the checkpoint.\n",
    "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
    "\n",
    "# We can create fake observations and actions to test the model.\n",
    "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
    "\n",
    "# Compute the training loss on the fake observations and actions.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the batch size to reduce memory usage.\n",
    "config = dataclasses.replace(config, batch_size=2)\n",
    "\n",
    "# Load a single batch of data. This is the same data that will be used during training.\n",
    "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
    "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
    "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
    "obs, act = next(iter(loader))\n",
    "\n",
    "# Compute the training loss on the real batch, reusing `model` and `key` from the previous cell.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "\n",
    "# Delete the model to free up memory.\n",
    "del model\n",
    "\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}