From c108bfe8407fee2bd442ccf500eed8343f61d011 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Wed, 12 Jun 2024 20:46:26 +0200 Subject: [PATCH] save state --- examples/real_robot_example/DEBUG_ME.md | 14 ++ .../convert_original_act_batch..ipynb | 92 +++++++ .../convert_original_act_checkpoint.ipynb | 234 +++++++++++++++--- lerobot/common/policies/act/modeling_act.py | 26 +- lerobot/common/policies/normalize.py | 4 +- lerobot/scripts/train.py | 54 ++++ 6 files changed, 385 insertions(+), 39 deletions(-) create mode 100644 examples/real_robot_example/DEBUG_ME.md create mode 100644 examples/real_robot_example/convert_original_act_batch..ipynb diff --git a/examples/real_robot_example/DEBUG_ME.md b/examples/real_robot_example/DEBUG_ME.md new file mode 100644 index 000000000..986239917 --- /dev/null +++ b/examples/real_robot_example/DEBUG_ME.md @@ -0,0 +1,14 @@ +inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt') +conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt') +out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1) +d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt') +print((out-d).abs().max()) +tensor(0.0044, device='cuda:0', grad_fn=) +inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt').to('cpu') +conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt').to('cpu') +out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1) +d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt') +print((out-d).abs().max()) +tensor(0., grad_fn=) +out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1) +torch.save(out, '/home/thomwolf/Documents/Github/ACT/tensor_out_lerobot.pt') diff --git a/examples/real_robot_example/convert_original_act_batch..ipynb b/examples/real_robot_example/convert_original_act_batch..ipynb new file mode 100644 index 000000000..82ee3f1d7 --- /dev/null +++ b/examples/real_robot_example/convert_original_act_batch..ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from pprint import pprint\n", + "import pickle\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "original_batch_file = \"/home/thomwolf/Documents/Github/ACT/batch_save.pt\"\n", + "data = torch.load(original_batch_file)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "#orig: image_data, qpos_data, action_data, is_pad\n", + "#target: ['observation.images.front', 'observation.images.top', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'action_is_pad']" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "conv = {}\n", + "conv['observation.images.front'] = data[0][:, 0]\n", + "conv['observation.images.top'] = data[0][:, 1]\n", + "conv['observation.state'] = data[1]\n", + "conv['action'] = data[2]\n", + "conv['episode_index'] = np.zeros(data[0].shape[0])\n", + "conv['frame_index'] = np.zeros(data[0].shape[0])\n", + "conv['timestamp'] = np.zeros(data[0].shape[0])\n", + "conv['next.done'] = np.zeros(data[0].shape[0])\n", + "conv['index'] = np.arange(data[0].shape[0])\n", + "conv['action_is_pad'] = data[3]" + ] + }, + { + "cell_type": "code", + 
"execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(conv, \"/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lerobot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.1.-1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/real_robot_example/convert_original_act_checkpoint.ipynb b/examples/real_robot_example/convert_original_act_checkpoint.ipynb index 92306b86f..6d22b45d7 100644 --- a/examples/real_robot_example/convert_original_act_checkpoint.ipynb +++ b/examples/real_robot_example/convert_original_act_checkpoint.ipynb @@ -2,23 +2,26 @@ "cells": [ { "cell_type": "code", - "execution_count": 48, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ "import torch\n", - "from safetensors.torch import load_file, save_file\n", - "from pprint import pprint" + "from safetensors.torch import load_file, save_file, _remove_duplicate_names\n", + "from pprint import pprint\n", + "import pickle\n", + "import numpy as np" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 122, "metadata": {}, "outputs": [], "source": [ - "original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/policy_last.ckpt\"\n", - "converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/model.safetensors\"\n", + "original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/policy_initial_state.ckpt\"\n", + "stats_path = original_ckpt_path.replace('policy_initial_state.ckpt', f'dataset_stats.pkl')\n", + "converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state/model.safetensors\"\n", "\n", "comparison_main_path = \"/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/\"\n", "comparison_safetensor_path = comparison_main_path + \"model.safetensors\"\n", @@ -28,7 +31,60 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 123, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "with open(stats_path, 'rb') as f:\n", + " stats = pickle.load(f)\n", + "image_stats_mean = torch.tensor([0.485, 0.456, 0.406])\n", + "image_stats_std = torch.tensor([0.229, 0.224, 0.225])" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'action_mean': array([ 0.10564873, -0.06760169, -0.21794759, 1.4932278 , 0.01712005,\n", + " 0.65135974], dtype=float32),\n", + " 'action_std': array([0.324479 , 0.48200357, 0.6450592 , 0.21122734, 0.16231349,\n", + " 0.46708518], dtype=float32),\n", + " 'qpos_mean': array([ 0.10348293, -0.01160635, -0.06582164, 1.490679 , 0.01592216,\n", + " 0.73001933], dtype=float32),\n", + " 'qpos_std': array([0.32309636, 0.4455181 , 0.59022886, 0.17841513, 0.16096668,\n", + " 0.32320786], dtype=float32),\n", + " 'example_qpos': array([[ 0.05212891, -0.58875 , 0.76200193, 1.4289453 , -0.0015332 ,\n", + " 0.9827832 ],\n", + " [ 0.05212891, -0.58875 , 0.76200193, 1.4304786 , -0.0015332 ,\n", 
+ " 0.9827832 ],\n", + " [ 0.05212891, -0.58875 , 0.76200193, 1.4304786 , -0.0015332 ,\n", + " 0.9827832 ],\n", + " ...,\n", + " [-0.03066406, -0.36030275, 0.7190723 , 1.1897656 , -0.18398437,\n", + " 0.9751172 ],\n", + " [-0.03066406, -0.35876954, 0.72367185, 1.1897656 , -0.18398437,\n", + " 0.9751172 ],\n", + " [-0.03066406, -0.35723633, 0.7282715 , 1.1897656 , -0.18398437,\n", + " 0.9751172 ]], dtype=float32)}" + ] + }, + "execution_count": 124, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stats" + ] + }, + { + "cell_type": "code", + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [ @@ -46,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 127, "metadata": {}, "outputs": [ { @@ -305,7 +361,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 128, "metadata": {}, "outputs": [ { @@ -686,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -747,7 +803,27 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 130, + "metadata": {}, + "outputs": [], + "source": [ + "conv['normalize_inputs.buffer_observation_images_front.mean'] = image_stats_mean[:, None, None]\n", + "conv['normalize_inputs.buffer_observation_images_front.std'] = image_stats_std[:, None, None]\n", + "conv['normalize_inputs.buffer_observation_images_top.mean'] = image_stats_mean[:, None, None].clone()\n", + "conv['normalize_inputs.buffer_observation_images_top.std'] = image_stats_std[:, None, None].clone()\n", + "\n", + "conv['normalize_inputs.buffer_observation_state.mean'] = torch.tensor(stats['qpos_mean'])\n", + "conv['normalize_inputs.buffer_observation_state.std'] = torch.tensor(stats['qpos_std'])\n", + "conv['normalize_targets.buffer_action.mean'] = torch.tensor(stats['action_mean'])\n", + "conv['normalize_targets.buffer_action.std'] = torch.tensor(stats['action_std'])\n", + "\n", + "conv['unnormalize_outputs.buffer_action.mean'] = torch.tensor(stats['action_mean'])\n", + "conv['unnormalize_outputs.buffer_action.std'] = torch.tensor(stats['action_std'])" + ] + }, + { + "cell_type": "code", + "execution_count": 131, "metadata": {}, "outputs": [ { @@ -756,7 +832,7 @@ "OrderedDict()" ] }, - "execution_count": 46, + "execution_count": 131, "metadata": {}, "output_type": "execute_result" } @@ -767,36 +843,91 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 132, "metadata": {}, "outputs": [], "source": [ + "not_converted = set(b.keys())\n", "for k, v in conv.items():\n", - " assert b[k].shape == v.shape\n", - " b[k] = v" + " try:\n", + " b[k].shape == v.squeeze().shape\n", + " except Exception as e:\n", + " print(k, v)\n", + " print(b[k].shape)\n", + " print(e)\n", + " b[k] = v\n", + " not_converted.remove(k)" ] }, { "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [], - "source": [ - "save_file(b, converted_ckpt_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 54, + "execution_count": 133, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/config.yaml'" + "set()" ] }, - "execution_count": 54, + "execution_count": 133, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "not_converted" + ] + }, + { + "cell_type": "code", + 
"execution_count": 134, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {})\n" + ] + } + ], + "source": [ + "metadata = None\n", + "to_removes = _remove_duplicate_names(b)\n", + "print(to_removes)" + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "metadata": {}, + "outputs": [], + "source": [ + "for kept_name, to_remove_group in to_removes.items():\n", + " for to_remove in to_remove_group:\n", + " if metadata is None:\n", + " metadata = {}\n", + "\n", + " if to_remove not in metadata:\n", + " # Do not override user data\n", + " metadata[to_remove] = kept_name\n", + " del b[to_remove]\n", + "save_file(b, converted_ckpt_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state/config.yaml'" + ] + }, + "execution_count": 136, "metadata": {}, "output_type": "execute_result" } @@ -808,6 +939,51 @@ "shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))" ] }, + { + "cell_type": "code", + "execution_count": 137, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['model.action_head.bias', 'model.action_head.weight', 'model.backbone.bn1.bias', 'model.backbone.bn1.running_mean', 'model.backbone.bn1.running_var', 'model.backbone.bn1.weight', 'model.backbone.conv1.weight', 'model.backbone.layer1.0.bn1.bias', 'model.backbone.layer1.0.bn1.running_mean', 'model.backbone.layer1.0.bn1.running_var', 'model.backbone.layer1.0.bn1.weight', 'model.backbone.layer1.0.bn2.bias', 'model.backbone.layer1.0.bn2.running_mean', 'model.backbone.layer1.0.bn2.running_var', 'model.backbone.layer1.0.bn2.weight', 'model.backbone.layer1.0.conv1.weight', 'model.backbone.layer1.0.conv2.weight', 'model.backbone.layer1.1.bn1.bias', 'model.backbone.layer1.1.bn1.running_mean', 'model.backbone.layer1.1.bn1.running_var', 'model.backbone.layer1.1.bn1.weight', 'model.backbone.layer1.1.bn2.bias', 'model.backbone.layer1.1.bn2.running_mean', 'model.backbone.layer1.1.bn2.running_var', 'model.backbone.layer1.1.bn2.weight', 'model.backbone.layer1.1.conv1.weight', 'model.backbone.layer1.1.conv2.weight', 'model.backbone.layer2.0.bn1.bias', 'model.backbone.layer2.0.bn1.running_mean', 'model.backbone.layer2.0.bn1.running_var', 'model.backbone.layer2.0.bn1.weight', 'model.backbone.layer2.0.bn2.bias', 'model.backbone.layer2.0.bn2.running_mean', 'model.backbone.layer2.0.bn2.running_var', 'model.backbone.layer2.0.bn2.weight', 'model.backbone.layer2.0.conv1.weight', 'model.backbone.layer2.0.conv2.weight', 'model.backbone.layer2.0.downsample.0.weight', 'model.backbone.layer2.0.downsample.1.bias', 'model.backbone.layer2.0.downsample.1.running_mean', 'model.backbone.layer2.0.downsample.1.running_var', 'model.backbone.layer2.0.downsample.1.weight', 'model.backbone.layer2.1.bn1.bias', 'model.backbone.layer2.1.bn1.running_mean', 'model.backbone.layer2.1.bn1.running_var', 'model.backbone.layer2.1.bn1.weight', 'model.backbone.layer2.1.bn2.bias', 'model.backbone.layer2.1.bn2.running_mean', 'model.backbone.layer2.1.bn2.running_var', 'model.backbone.layer2.1.bn2.weight', 'model.backbone.layer2.1.conv1.weight', 'model.backbone.layer2.1.conv2.weight', 'model.backbone.layer3.0.bn1.bias', 'model.backbone.layer3.0.bn1.running_mean', 'model.backbone.layer3.0.bn1.running_var', 'model.backbone.layer3.0.bn1.weight', 'model.backbone.layer3.0.bn2.bias', 
'model.backbone.layer3.0.bn2.running_mean', 'model.backbone.layer3.0.bn2.running_var', 'model.backbone.layer3.0.bn2.weight', 'model.backbone.layer3.0.conv1.weight', 'model.backbone.layer3.0.conv2.weight', 'model.backbone.layer3.0.downsample.0.weight', 'model.backbone.layer3.0.downsample.1.bias', 'model.backbone.layer3.0.downsample.1.running_mean', 'model.backbone.layer3.0.downsample.1.running_var', 'model.backbone.layer3.0.downsample.1.weight', 'model.backbone.layer3.1.bn1.bias', 'model.backbone.layer3.1.bn1.running_mean', 'model.backbone.layer3.1.bn1.running_var', 'model.backbone.layer3.1.bn1.weight', 'model.backbone.layer3.1.bn2.bias', 'model.backbone.layer3.1.bn2.running_mean', 'model.backbone.layer3.1.bn2.running_var', 'model.backbone.layer3.1.bn2.weight', 'model.backbone.layer3.1.conv1.weight', 'model.backbone.layer3.1.conv2.weight', 'model.backbone.layer4.0.bn1.bias', 'model.backbone.layer4.0.bn1.running_mean', 'model.backbone.layer4.0.bn1.running_var', 'model.backbone.layer4.0.bn1.weight', 'model.backbone.layer4.0.bn2.bias', 'model.backbone.layer4.0.bn2.running_mean', 'model.backbone.layer4.0.bn2.running_var', 'model.backbone.layer4.0.bn2.weight', 'model.backbone.layer4.0.conv1.weight', 'model.backbone.layer4.0.conv2.weight', 'model.backbone.layer4.0.downsample.0.weight', 'model.backbone.layer4.0.downsample.1.bias', 'model.backbone.layer4.0.downsample.1.running_mean', 'model.backbone.layer4.0.downsample.1.running_var', 'model.backbone.layer4.0.downsample.1.weight', 'model.backbone.layer4.1.bn1.bias', 'model.backbone.layer4.1.bn1.running_mean', 'model.backbone.layer4.1.bn1.running_var', 'model.backbone.layer4.1.bn1.weight', 'model.backbone.layer4.1.bn2.bias', 'model.backbone.layer4.1.bn2.running_mean', 'model.backbone.layer4.1.bn2.running_var', 'model.backbone.layer4.1.bn2.weight', 'model.backbone.layer4.1.conv1.weight', 'model.backbone.layer4.1.conv2.weight', 'model.decoder.layers.0.linear1.bias', 'model.decoder.layers.0.linear1.weight', 'model.decoder.layers.0.linear2.bias', 'model.decoder.layers.0.linear2.weight', 'model.decoder.layers.0.multihead_attn.in_proj_bias', 'model.decoder.layers.0.multihead_attn.in_proj_weight', 'model.decoder.layers.0.multihead_attn.out_proj.bias', 'model.decoder.layers.0.multihead_attn.out_proj.weight', 'model.decoder.layers.0.norm1.bias', 'model.decoder.layers.0.norm1.weight', 'model.decoder.layers.0.norm2.bias', 'model.decoder.layers.0.norm2.weight', 'model.decoder.layers.0.norm3.bias', 'model.decoder.layers.0.norm3.weight', 'model.decoder.layers.0.self_attn.in_proj_bias', 'model.decoder.layers.0.self_attn.in_proj_weight', 'model.decoder.layers.0.self_attn.out_proj.bias', 'model.decoder.layers.0.self_attn.out_proj.weight', 'model.decoder_pos_embed.weight', 'model.encoder.layers.0.linear1.bias', 'model.encoder.layers.0.linear1.weight', 'model.encoder.layers.0.linear2.bias', 'model.encoder.layers.0.linear2.weight', 'model.encoder.layers.0.norm1.bias', 'model.encoder.layers.0.norm1.weight', 'model.encoder.layers.0.norm2.bias', 'model.encoder.layers.0.norm2.weight', 'model.encoder.layers.0.self_attn.in_proj_bias', 'model.encoder.layers.0.self_attn.in_proj_weight', 'model.encoder.layers.0.self_attn.out_proj.bias', 'model.encoder.layers.0.self_attn.out_proj.weight', 'model.encoder.layers.1.linear1.bias', 'model.encoder.layers.1.linear1.weight', 'model.encoder.layers.1.linear2.bias', 'model.encoder.layers.1.linear2.weight', 'model.encoder.layers.1.norm1.bias', 'model.encoder.layers.1.norm1.weight', 'model.encoder.layers.1.norm2.bias', 
'model.encoder.layers.1.norm2.weight', 'model.encoder.layers.1.self_attn.in_proj_bias', 'model.encoder.layers.1.self_attn.in_proj_weight', 'model.encoder.layers.1.self_attn.out_proj.bias', 'model.encoder.layers.1.self_attn.out_proj.weight', 'model.encoder.layers.2.linear1.bias', 'model.encoder.layers.2.linear1.weight', 'model.encoder.layers.2.linear2.bias', 'model.encoder.layers.2.linear2.weight', 'model.encoder.layers.2.norm1.bias', 'model.encoder.layers.2.norm1.weight', 'model.encoder.layers.2.norm2.bias', 'model.encoder.layers.2.norm2.weight', 'model.encoder.layers.2.self_attn.in_proj_bias', 'model.encoder.layers.2.self_attn.in_proj_weight', 'model.encoder.layers.2.self_attn.out_proj.bias', 'model.encoder.layers.2.self_attn.out_proj.weight', 'model.encoder.layers.3.linear1.bias', 'model.encoder.layers.3.linear1.weight', 'model.encoder.layers.3.linear2.bias', 'model.encoder.layers.3.linear2.weight', 'model.encoder.layers.3.norm1.bias', 'model.encoder.layers.3.norm1.weight', 'model.encoder.layers.3.norm2.bias', 'model.encoder.layers.3.norm2.weight', 'model.encoder.layers.3.self_attn.in_proj_bias', 'model.encoder.layers.3.self_attn.in_proj_weight', 'model.encoder.layers.3.self_attn.out_proj.bias', 'model.encoder.layers.3.self_attn.out_proj.weight', 'model.encoder_img_feat_input_proj.bias', 'model.encoder_img_feat_input_proj.weight', 'model.encoder_latent_input_proj.bias', 'model.encoder_latent_input_proj.weight', 'model.encoder_robot_and_latent_pos_embed.weight', 'model.encoder_robot_state_input_proj.bias', 'model.encoder_robot_state_input_proj.weight', 'model.vae_encoder.layers.0.linear1.bias', 'model.vae_encoder.layers.0.linear1.weight', 'model.vae_encoder.layers.0.linear2.bias', 'model.vae_encoder.layers.0.linear2.weight', 'model.vae_encoder.layers.0.norm1.bias', 'model.vae_encoder.layers.0.norm1.weight', 'model.vae_encoder.layers.0.norm2.bias', 'model.vae_encoder.layers.0.norm2.weight', 'model.vae_encoder.layers.0.self_attn.in_proj_bias', 'model.vae_encoder.layers.0.self_attn.in_proj_weight', 'model.vae_encoder.layers.0.self_attn.out_proj.bias', 'model.vae_encoder.layers.0.self_attn.out_proj.weight', 'model.vae_encoder.layers.1.linear1.bias', 'model.vae_encoder.layers.1.linear1.weight', 'model.vae_encoder.layers.1.linear2.bias', 'model.vae_encoder.layers.1.linear2.weight', 'model.vae_encoder.layers.1.norm1.bias', 'model.vae_encoder.layers.1.norm1.weight', 'model.vae_encoder.layers.1.norm2.bias', 'model.vae_encoder.layers.1.norm2.weight', 'model.vae_encoder.layers.1.self_attn.in_proj_bias', 'model.vae_encoder.layers.1.self_attn.in_proj_weight', 'model.vae_encoder.layers.1.self_attn.out_proj.bias', 'model.vae_encoder.layers.1.self_attn.out_proj.weight', 'model.vae_encoder.layers.2.linear1.bias', 'model.vae_encoder.layers.2.linear1.weight', 'model.vae_encoder.layers.2.linear2.bias', 'model.vae_encoder.layers.2.linear2.weight', 'model.vae_encoder.layers.2.norm1.bias', 'model.vae_encoder.layers.2.norm1.weight', 'model.vae_encoder.layers.2.norm2.bias', 'model.vae_encoder.layers.2.norm2.weight', 'model.vae_encoder.layers.2.self_attn.in_proj_bias', 'model.vae_encoder.layers.2.self_attn.in_proj_weight', 'model.vae_encoder.layers.2.self_attn.out_proj.bias', 'model.vae_encoder.layers.2.self_attn.out_proj.weight', 'model.vae_encoder.layers.3.linear1.bias', 'model.vae_encoder.layers.3.linear1.weight', 'model.vae_encoder.layers.3.linear2.bias', 'model.vae_encoder.layers.3.linear2.weight', 'model.vae_encoder.layers.3.norm1.bias', 'model.vae_encoder.layers.3.norm1.weight', 
'model.vae_encoder.layers.3.norm2.bias', 'model.vae_encoder.layers.3.norm2.weight', 'model.vae_encoder.layers.3.self_attn.in_proj_bias', 'model.vae_encoder.layers.3.self_attn.in_proj_weight', 'model.vae_encoder.layers.3.self_attn.out_proj.bias', 'model.vae_encoder.layers.3.self_attn.out_proj.weight', 'model.vae_encoder_action_input_proj.bias', 'model.vae_encoder_action_input_proj.weight', 'model.vae_encoder_cls_embed.weight', 'model.vae_encoder_latent_output_proj.bias', 'model.vae_encoder_latent_output_proj.weight', 'model.vae_encoder_pos_enc', 'model.vae_encoder_robot_state_input_proj.bias', 'model.vae_encoder_robot_state_input_proj.weight', 'normalize_inputs.buffer_observation_images_front.mean', 'normalize_inputs.buffer_observation_images_front.std', 'normalize_inputs.buffer_observation_images_top.mean', 'normalize_inputs.buffer_observation_images_top.std', 'normalize_inputs.buffer_observation_state.mean', 'normalize_inputs.buffer_observation_state.std', 'normalize_targets.buffer_action.mean', 'normalize_targets.buffer_action.std', 'unnormalize_outputs.buffer_action.mean', 'unnormalize_outputs.buffer_action.std'])" + ] + }, + "execution_count": 137, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "c = load_file(converted_ckpt_path)\n", + "c.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[0.4850]],\n", + "\n", + " [[0.4560]],\n", + "\n", + " [[0.4060]]])" + ] + }, + "execution_count": 105, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "c['normalize_inputs.buffer_observation_images_front.mean']" + ] + }, { "cell_type": "code", "execution_count": null, @@ -832,7 +1008,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.1.-1" } }, "nbformat": 4, diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index bbbb512dd..42f7c1200 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -32,7 +32,6 @@ import torchvision from huggingface_hub import PyTorchModelHubMixin from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter -from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.normalize import Normalize, Unnormalize @@ -75,7 +74,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): self.model = ACT(config) - self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + self.expected_image_keys = [ + k for k in sorted(config.input_shapes) if k.startswith("observation.image") + ] self.reset() @@ -135,7 +136,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch["observation.images"] = torch.stack( + [batch[k] for k in sorted(self.expected_image_keys)], dim=-4 + ) batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) @@ -228,13 +231,18 @@ class ACT(nn.Module): # Backbone for image feature extraction. 
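        # NOTE (debug save-state): the changes just below intentionally diverge from the
        # committed ACT implementation while chasing numerical parity with the original
        # ACT repo: backbone weights are hard-coded to torchvision's "DEFAULT" instead of
        # config.pretrained_backbone_weights, FrozenBatchNorm2d is dropped in favour of
        # torchvision's default norm layer, and IntermediateLayerGetter is asked for all
        # four ResNet stages ("0".."3") even though only the final stage (layer4) is
        # consumed further down in the forward pass.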
backbone_model = getattr(torchvision.models, config.vision_backbone)( replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], - weights=config.pretrained_backbone_weights, - norm_layer=FrozenBatchNorm2d, + weights="DEFAULT", # config.pretrained_backbone_weights, + # norm_layer=FrozenBatchNorm2d, ) # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature # map). # Note: The forward method of this returns a dict: {"feature_map": output}. - self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + # TODO thom fix this + # self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + self.backbone = IntermediateLayerGetter( + backbone_model, return_layers={"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + ) # Transformer (acts as VAE decoder when training with the variational objective). self.encoder = ACTEncoder(config) @@ -294,7 +302,7 @@ class ACT(nn.Module): batch_size = batch["observation.images"].shape[0] # Prepare the latent for input to the transformer encoder. - if self.config.use_vae and "action" in batch: + if False: ###### TODO(thom) remove this self.config.use_vae and "action" in batch: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size @@ -346,7 +354,9 @@ class ACT(nn.Module): images = batch["observation.images"] for cam_index in range(images.shape[-4]): - cam_features = self.backbone(images[:, cam_index])["feature_map"] + torch.backends.cudnn.deterministic = True + cam_features = self.backbone(images[:, cam_index]) + cam_features = cam_features[3] # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 9b055f7e6..8b2de1387 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -140,14 +140,14 @@ class Normalize(nn.Module): std = buffer["std"] assert not torch.isinf(mean).any(), _no_stats_error_str("mean") assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) + batch[key].sub_(mean).div_(std) elif mode == "min_max": min = buffer["min"] max = buffer["max"] assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(max).any(), _no_stats_error_str("max") # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min + 1e-8) + batch[key] = (batch[key] - min) / (max - min) # normalize to [-1, 1] batch[key] = batch[key] * 2 - 1 else: diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 3b5b8948b..5ee138b46 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -44,6 +44,12 @@ from lerobot.common.utils.utils import ( ) from lerobot.scripts.eval import eval_policy +################## TODO remove this part +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +torch.use_deterministic_algorithms(True) +################## + def make_optimizer_and_scheduler(cfg, policy): if cfg.policy.name == "act": @@ -104,7 +110,55 @@ def update_policy( start_time = time.perf_counter() device = 
get_device_from_parameters(policy) policy.train() + + ################## TODO remove this part + pretrained_policy_name_or_path = ( + "/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state" + ) + from lerobot.common.policies.act.modeling_act import ACTPolicy + + policy_cls = ACTPolicy + policy_cfg = policy.config + policy = policy_cls(policy_cfg) + policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) + policy.to(device) + + policy.eval() # No dropout + ################## + with torch.autocast(device_type=device.type) if use_amp else nullcontext(): + ########################### TODO remove this part + batch = torch.load("/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt", map_location=device) + + # print some stats + def model_stats(model): + na = [n for n, a in model.named_parameters() if "normalize_" not in n] + me = [a.mean().item() for n, a in model.named_parameters() if "normalize_" not in n] + print(na[me.index(min(me))], min(me)) + print(sum(me)) + mi = [a.min().item() for n, a in model.named_parameters() if "normalize_" not in n] + print(na[mi.index(min(mi))], min(mi)) + print(sum(mi)) + ma = [a.max().item() for n, a in model.named_parameters() if "normalize_" not in n] + print(na[ma.index(max(ma))], max(ma)) + print(sum(ma)) + + model_stats(policy) + + def batch_stats(data): + print(min(d.min() for d in data)) + print(max(d.max() for d in data)) + + data = ( + batch["observation.images.front"], + batch["observation.images.top"], + batch["observation.state"], + batch["action"], + ) + batch_stats(data) + + ########################### + output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) loss = output_dict["loss"]
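
Standalone parity check (sketch, not part of the patch). The pieces above converge on a single comparison: run the converted checkpoint on the converted batch with deterministic kernels and compare the resulting loss against the original ACT repo. Below is a minimal sketch of that check under stated assumptions: `policy_cfg` is a placeholder for the ACTConfig of the comparison training run (two cameras, 6-dof state/action), the two-step weight load mirrors the debug block added to lerobot/scripts/train.py rather than relying on from_pretrained alone, and the paths are the ones hard-coded throughout this save state.

import torch
from lerobot.common.policies.act.modeling_act import ACTPolicy

# Same determinism settings as the debug block in lerobot/scripts/train.py
# (conv2d on CUDA is the source of the ~4e-3 drift isolated in DEBUG_ME.md).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)  # may require CUBLAS_WORKSPACE_CONFIG=:4096:8 on CUDA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# policy_cfg: placeholder for the ACTConfig used by the comparison run; the converted
# weights are then copied over this explicitly configured policy, as in train.py.
policy = ACTPolicy(policy_cfg)
pretrained_dir = "/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state"
policy.load_state_dict(ACTPolicy.from_pretrained(pretrained_dir).state_dict())
policy.to(device)
policy.eval()  # no dropout, so repeated runs are comparable

# Batch converted from the original ACT dataloader format by convert_original_act_batch..ipynb.
batch = torch.load("/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt", map_location=device)

with torch.no_grad():
    output_dict = policy.forward(batch)
print("loss:", float(output_dict["loss"]))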