4090 work
Some checks are pending
pre-commit / pre-commit (push) Waiting to run

This commit is contained in:
2025-04-26 22:25:28 +08:00
parent 65d864861b
commit 55fed92ccc
2 changed files with 62 additions and 4 deletions

View File

@@ -6,6 +6,8 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['HF_ENDPOINT'] = \"https://hf-mirror.com\"\n",
"import dataclasses\n",
"\n",
"import jax\n",
@@ -18,6 +20,13 @@
"from openpi.training import data_loader as _data_loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
@@ -31,10 +40,53 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fa8d45bf6fe5420f8b152ff52794ee45",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0.00/11.2G [00:00<?, ?iB/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"Some kwargs in processor config are unused and will not have any effect: action_dim, scale, time_horizon, vocab_size, min_token. \n",
"Some kwargs in processor config are unused and will not have any effect: action_dim, scale, time_horizon, vocab_size, min_token. \n"
]
},
{
"ename": "ValueError",
"evalue": "quantile stats must be provided if use_quantile_norm is True. Key actions is missing q01 or q99.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m checkpoint_dir \u001b[38;5;241m=\u001b[39m download\u001b[38;5;241m.\u001b[39mmaybe_download(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ms3://openpi-assets/checkpoints/pi0_base\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Create a trained policy.\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m policy \u001b[38;5;241m=\u001b[39m \u001b[43m_policy_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_trained_policy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m example \u001b[38;5;241m=\u001b[39m droid_policy\u001b[38;5;241m.\u001b[39mmake_droid_example()\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/policies/policy_config.py:72\u001b[0m, in \u001b[0;36mcreate_trained_policy\u001b[0;34m(train_config, checkpoint_dir, repack_transforms, sample_kwargs, default_prompt, norm_stats)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAsset id is required to load norm stats.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 64\u001b[0m norm_stats \u001b[38;5;241m=\u001b[39m _checkpoints\u001b[38;5;241m.\u001b[39mload_norm_stats(checkpoint_dir \u001b[38;5;241m/\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massets\u001b[39m\u001b[38;5;124m\"\u001b[39m, data_config\u001b[38;5;241m.\u001b[39masset_id)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _policy\u001b[38;5;241m.\u001b[39mPolicy(\n\u001b[1;32m 67\u001b[0m model,\n\u001b[1;32m 68\u001b[0m transforms\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 69\u001b[0m \u001b[38;5;241m*\u001b[39mrepack_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[1;32m 70\u001b[0m transforms\u001b[38;5;241m.\u001b[39mInjectDefaultPrompt(default_prompt),\n\u001b[1;32m 71\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mdata_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[0;32m---> 72\u001b[0m \u001b[43mtransforms\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mNormalize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnorm_stats\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_quantiles\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_quantile_norm\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 73\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mmodel_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[1;32m 74\u001b[0m ],\n\u001b[1;32m 75\u001b[0m output_transforms\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 76\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mmodel_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 77\u001b[0m transforms\u001b[38;5;241m.\u001b[39mUnnormalize(norm_stats, use_quantiles\u001b[38;5;241m=\u001b[39mdata_config\u001b[38;5;241m.\u001b[39muse_quantile_norm),\n\u001b[1;32m 78\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mdata_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 79\u001b[0m \u001b[38;5;241m*\u001b[39mrepack_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 80\u001b[0m ],\n\u001b[1;32m 81\u001b[0m sample_kwargs\u001b[38;5;241m=\u001b[39msample_kwargs,\n\u001b[1;32m 82\u001b[0m metadata\u001b[38;5;241m=\u001b[39mtrain_config\u001b[38;5;241m.\u001b[39mpolicy_metadata,\n\u001b[1;32m 83\u001b[0m )\n",
"File \u001b[0;32m<string>:6\u001b[0m, in \u001b[0;36m__init__\u001b[0;34m(self, norm_stats, use_quantiles, strict)\u001b[0m\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/transforms.py:124\u001b[0m, in \u001b[0;36mNormalize.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__post_init__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm_stats \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_quantiles:\n\u001b[0;32m--> 124\u001b[0m \u001b[43m_assert_quantile_stats\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_stats\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/transforms.py:431\u001b[0m, in \u001b[0;36m_assert_quantile_stats\u001b[0;34m(norm_stats)\u001b[0m\n\u001b[1;32m 429\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m flatten_dict(norm_stats)\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m v\u001b[38;5;241m.\u001b[39mq01 \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m v\u001b[38;5;241m.\u001b[39mq99 \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 431\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 432\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantile stats must be provided if use_quantile_norm is True. Key \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is missing q01 or q99.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 433\u001b[0m )\n",
"\u001b[0;31mValueError\u001b[0m: quantile stats must be provided if use_quantile_norm is True. Key actions is missing q01 or q99."
]
}
],
"source": [
"\n",
"config = _config.get_config(\"pi0_fast_droid\")\n",
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
"# checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_base\")\n",
"\n",
"# Create a trained policy.\n",
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
@@ -129,7 +181,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.11.12"
}
},
"nbformat": 4,