{ "cells": [ { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pickle\n", "from pprint import pprint\n", "\n", "import numpy as np\n", "import torch\n", "from safetensors.torch import load_file, save_file, _remove_duplicate_names" ] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [], "source": [ "# Input/output locations. Each base directory can be overridden with an environment variable so the\n", "# notebook is not tied to a single machine; the defaults are the original hard-coded paths.\n", "\n", "# Original ACT run: source checkpoint, its pickled dataset stats, and the path for the converted safetensors file.\n", "act_ckpt_dir = os.environ.get('ACT_CKPT_DIR', '/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw')\n", "original_ckpt_path = os.path.join(act_ckpt_dir, 'policy_initial_state.ckpt')\n", "stats_path = os.path.join(act_ckpt_dir, 'dataset_stats.pkl')\n", "converted_ckpt_path = os.path.join(act_ckpt_dir, 'initial_state', 'model.safetensors')\n", "\n", "# LeRobot pretrained model whose state-dict keys / config files are used as the comparison reference.\n", "# os.path.join (instead of string concatenation) works whether or not the directory ends with a slash.\n", "comparison_main_path = os.environ.get('LEROBOT_PRETRAINED_DIR', '/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/')\n", "comparison_safetensor_path = os.path.join(comparison_main_path, 'model.safetensors')\n", "comparison_config_json_path = os.path.join(comparison_main_path, 'config.json')\n", "comparison_config_yaml_path = os.path.join(comparison_main_path, 'config.yaml')" ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [], "source": [ "# NOTE: pickle.load can execute arbitrary code; only load a stats file you produced or otherwise trust.\n", "with open(stats_path, 'rb') as f:\n", "    stats = pickle.load(f)\n", "\n", "# Per-channel image normalization constants (these values match the standard ImageNet RGB mean/std).\n", "image_stats_mean = torch.tensor([0.485, 0.456, 0.406])\n", "image_stats_std = torch.tensor([0.229, 0.224, 0.225])" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'action_mean': array([ 0.10564873, -0.06760169, -0.21794759, 1.4932278 , 0.01712005,\n", " 0.65135974], dtype=float32),\n", " 'action_std': array([0.324479 , 0.48200357, 0.6450592 , 0.21122734, 0.16231349,\n", " 0.46708518], dtype=float32),\n", " 'qpos_mean': array([ 0.10348293, -0.01160635, -0.06582164, 1.490679 , 0.01592216,\n", " 0.73001933], dtype=float32),\n", " 'qpos_std': array([0.32309636, 0.4455181 , 0.59022886, 0.17841513, 0.16096668,\n", " 0.32320786],
dtype=float32),\n", " 'example_qpos': array([[ 0.05212891, -0.58875 , 0.76200193, 1.4289453 , -0.0015332 ,\n", " 0.9827832 ],\n", " [ 0.05212891, -0.58875 , 0.76200193, 1.4304786 , -0.0015332 ,\n", " 0.9827832 ],\n", " [ 0.05212891, -0.58875 , 0.76200193, 1.4304786 , -0.0015332 ,\n", " 0.9827832 ],\n", " ...,\n", " [-0.03066406, -0.36030275, 0.7190723 , 1.1897656 , -0.18398437,\n", " 0.9751172 ],\n", " [-0.03066406, -0.35876954, 0.72367185, 1.1897656 , -0.18398437,\n", " 0.9751172 ],\n", " [-0.03066406, -0.35723633, 0.7282715 , 1.1897656 , -0.18398437,\n", " 0.9751172 ]], dtype=float32)}" ] }, "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stats" ] }, { "cell_type": "code", "execution_count": 125, "metadata": {}, "outputs": [], "source": [ "# Original ACT checkpoint: a dict mapping parameter/buffer names to tensors.\n", "# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only machine and keeps these tensors\n", "# on the same device as the safetensors tensors loaded in the next cell.\n", "a = torch.load(original_ckpt_path, map_location='cpu')" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [], "source": [ "# Reference LeRobot state dict (name -> tensor); load_file places tensors on CPU by default.\n", "b = load_file(comparison_safetensor_path)" ] }, { "cell_type": "code", "execution_count": 127, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['model.action_head.bias',\n", " 'model.action_head.weight',\n", " 'model.backbone.bn1.bias',\n", " 'model.backbone.bn1.running_mean',\n", " 'model.backbone.bn1.running_var',\n", " 'model.backbone.bn1.weight',\n", " 'model.backbone.conv1.weight',\n", " 'model.backbone.layer1.0.bn1.bias',\n", " 'model.backbone.layer1.0.bn1.running_mean',\n", " 'model.backbone.layer1.0.bn1.running_var',\n", " 'model.backbone.layer1.0.bn1.weight',\n", " 'model.backbone.layer1.0.bn2.bias',\n", " 'model.backbone.layer1.0.bn2.running_mean',\n", " 'model.backbone.layer1.0.bn2.running_var',\n", " 'model.backbone.layer1.0.bn2.weight',\n", " 'model.backbone.layer1.0.conv1.weight',\n", " 'model.backbone.layer1.0.conv2.weight',\n", " 'model.backbone.layer1.1.bn1.bias',\n", " 'model.backbone.layer1.1.bn1.running_mean',\n", " 'model.backbone.layer1.1.bn1.running_var',\n", "
'model.backbone.layer1.1.bn1.weight',\n", " 'model.backbone.layer1.1.bn2.bias',\n", " 'model.backbone.layer1.1.bn2.running_mean',\n", " 'model.backbone.layer1.1.bn2.running_var',\n", " 'model.backbone.layer1.1.bn2.weight',\n", " 'model.backbone.layer1.1.conv1.weight',\n", " 'model.backbone.layer1.1.conv2.weight',\n", " 'model.backbone.layer2.0.bn1.bias',\n", " 'model.backbone.layer2.0.bn1.running_mean',\n", " 'model.backbone.layer2.0.bn1.running_var',\n", " 'model.backbone.layer2.0.bn1.weight',\n", " 'model.backbone.layer2.0.bn2.bias',\n", " 'model.backbone.layer2.0.bn2.running_mean',\n", " 'model.backbone.layer2.0.bn2.running_var',\n", " 'model.backbone.layer2.0.bn2.weight',\n", " 'model.backbone.layer2.0.conv1.weight',\n", " 'model.backbone.layer2.0.conv2.weight',\n", " 'model.backbone.layer2.0.downsample.0.weight',\n", " 'model.backbone.layer2.0.downsample.1.bias',\n", " 'model.backbone.layer2.0.downsample.1.running_mean',\n", " 'model.backbone.layer2.0.downsample.1.running_var',\n", " 'model.backbone.layer2.0.downsample.1.weight',\n", " 'model.backbone.layer2.1.bn1.bias',\n", " 'model.backbone.layer2.1.bn1.running_mean',\n", " 'model.backbone.layer2.1.bn1.running_var',\n", " 'model.backbone.layer2.1.bn1.weight',\n", " 'model.backbone.layer2.1.bn2.bias',\n", " 'model.backbone.layer2.1.bn2.running_mean',\n", " 'model.backbone.layer2.1.bn2.running_var',\n", " 'model.backbone.layer2.1.bn2.weight',\n", " 'model.backbone.layer2.1.conv1.weight',\n", " 'model.backbone.layer2.1.conv2.weight',\n", " 'model.backbone.layer3.0.bn1.bias',\n", " 'model.backbone.layer3.0.bn1.running_mean',\n", " 'model.backbone.layer3.0.bn1.running_var',\n", " 'model.backbone.layer3.0.bn1.weight',\n", " 'model.backbone.layer3.0.bn2.bias',\n", " 'model.backbone.layer3.0.bn2.running_mean',\n", " 'model.backbone.layer3.0.bn2.running_var',\n", " 'model.backbone.layer3.0.bn2.weight',\n", " 'model.backbone.layer3.0.conv1.weight',\n", " 'model.backbone.layer3.0.conv2.weight',\n", " 
'model.backbone.layer3.0.downsample.0.weight',\n", " 'model.backbone.layer3.0.downsample.1.bias',\n", " 'model.backbone.layer3.0.downsample.1.running_mean',\n", " 'model.backbone.layer3.0.downsample.1.running_var',\n", " 'model.backbone.layer3.0.downsample.1.weight',\n", " 'model.backbone.layer3.1.bn1.bias',\n", " 'model.backbone.layer3.1.bn1.running_mean',\n", " 'model.backbone.layer3.1.bn1.running_var',\n", " 'model.backbone.layer3.1.bn1.weight',\n", " 'model.backbone.layer3.1.bn2.bias',\n", " 'model.backbone.layer3.1.bn2.running_mean',\n", " 'model.backbone.layer3.1.bn2.running_var',\n", " 'model.backbone.layer3.1.bn2.weight',\n", " 'model.backbone.layer3.1.conv1.weight',\n", " 'model.backbone.layer3.1.conv2.weight',\n", " 'model.backbone.layer4.0.bn1.bias',\n", " 'model.backbone.layer4.0.bn1.running_mean',\n", " 'model.backbone.layer4.0.bn1.running_var',\n", " 'model.backbone.layer4.0.bn1.weight',\n", " 'model.backbone.layer4.0.bn2.bias',\n", " 'model.backbone.layer4.0.bn2.running_mean',\n", " 'model.backbone.layer4.0.bn2.running_var',\n", " 'model.backbone.layer4.0.bn2.weight',\n", " 'model.backbone.layer4.0.conv1.weight',\n", " 'model.backbone.layer4.0.conv2.weight',\n", " 'model.backbone.layer4.0.downsample.0.weight',\n", " 'model.backbone.layer4.0.downsample.1.bias',\n", " 'model.backbone.layer4.0.downsample.1.running_mean',\n", " 'model.backbone.layer4.0.downsample.1.running_var',\n", " 'model.backbone.layer4.0.downsample.1.weight',\n", " 'model.backbone.layer4.1.bn1.bias',\n", " 'model.backbone.layer4.1.bn1.running_mean',\n", " 'model.backbone.layer4.1.bn1.running_var',\n", " 'model.backbone.layer4.1.bn1.weight',\n", " 'model.backbone.layer4.1.bn2.bias',\n", " 'model.backbone.layer4.1.bn2.running_mean',\n", " 'model.backbone.layer4.1.bn2.running_var',\n", " 'model.backbone.layer4.1.bn2.weight',\n", " 'model.backbone.layer4.1.conv1.weight',\n", " 'model.backbone.layer4.1.conv2.weight',\n", " 'model.decoder.layers.0.linear1.bias',\n", " 
'model.decoder.layers.0.linear1.weight',\n", " 'model.decoder.layers.0.linear2.bias',\n", " 'model.decoder.layers.0.linear2.weight',\n", " 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n", " 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n", " 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n", " 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n", " 'model.decoder.layers.0.norm1.bias',\n", " 'model.decoder.layers.0.norm1.weight',\n", " 'model.decoder.layers.0.norm2.bias',\n", " 'model.decoder.layers.0.norm2.weight',\n", " 'model.decoder.layers.0.norm3.bias',\n", " 'model.decoder.layers.0.norm3.weight',\n", " 'model.decoder.layers.0.self_attn.in_proj_bias',\n", " 'model.decoder.layers.0.self_attn.in_proj_weight',\n", " 'model.decoder.layers.0.self_attn.out_proj.bias',\n", " 'model.decoder.layers.0.self_attn.out_proj.weight',\n", " 'model.decoder_pos_embed.weight',\n", " 'model.encoder.layers.0.linear1.bias',\n", " 'model.encoder.layers.0.linear1.weight',\n", " 'model.encoder.layers.0.linear2.bias',\n", " 'model.encoder.layers.0.linear2.weight',\n", " 'model.encoder.layers.0.norm1.bias',\n", " 'model.encoder.layers.0.norm1.weight',\n", " 'model.encoder.layers.0.norm2.bias',\n", " 'model.encoder.layers.0.norm2.weight',\n", " 'model.encoder.layers.0.self_attn.in_proj_bias',\n", " 'model.encoder.layers.0.self_attn.in_proj_weight',\n", " 'model.encoder.layers.0.self_attn.out_proj.bias',\n", " 'model.encoder.layers.0.self_attn.out_proj.weight',\n", " 'model.encoder.layers.1.linear1.bias',\n", " 'model.encoder.layers.1.linear1.weight',\n", " 'model.encoder.layers.1.linear2.bias',\n", " 'model.encoder.layers.1.linear2.weight',\n", " 'model.encoder.layers.1.norm1.bias',\n", " 'model.encoder.layers.1.norm1.weight',\n", " 'model.encoder.layers.1.norm2.bias',\n", " 'model.encoder.layers.1.norm2.weight',\n", " 'model.encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.encoder.layers.1.self_attn.in_proj_weight',\n", " 
'model.encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.encoder.layers.2.linear1.bias',\n", " 'model.encoder.layers.2.linear1.weight',\n", " 'model.encoder.layers.2.linear2.bias',\n", " 'model.encoder.layers.2.linear2.weight',\n", " 'model.encoder.layers.2.norm1.bias',\n", " 'model.encoder.layers.2.norm1.weight',\n", " 'model.encoder.layers.2.norm2.bias',\n", " 'model.encoder.layers.2.norm2.weight',\n", " 'model.encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.encoder.layers.2.self_attn.out_proj.weight',\n", " 'model.encoder.layers.3.linear1.bias',\n", " 'model.encoder.layers.3.linear1.weight',\n", " 'model.encoder.layers.3.linear2.bias',\n", " 'model.encoder.layers.3.linear2.weight',\n", " 'model.encoder.layers.3.norm1.bias',\n", " 'model.encoder.layers.3.norm1.weight',\n", " 'model.encoder.layers.3.norm2.bias',\n", " 'model.encoder.layers.3.norm2.weight',\n", " 'model.encoder.layers.3.self_attn.in_proj_bias',\n", " 'model.encoder.layers.3.self_attn.in_proj_weight',\n", " 'model.encoder.layers.3.self_attn.out_proj.bias',\n", " 'model.encoder.layers.3.self_attn.out_proj.weight',\n", " 'model.encoder_img_feat_input_proj.bias',\n", " 'model.encoder_img_feat_input_proj.weight',\n", " 'model.encoder_latent_input_proj.bias',\n", " 'model.encoder_latent_input_proj.weight',\n", " 'model.encoder_robot_and_latent_pos_embed.weight',\n", " 'model.encoder_robot_state_input_proj.bias',\n", " 'model.encoder_robot_state_input_proj.weight',\n", " 'model.vae_encoder.layers.0.linear1.bias',\n", " 'model.vae_encoder.layers.0.linear1.weight',\n", " 'model.vae_encoder.layers.0.linear2.bias',\n", " 'model.vae_encoder.layers.0.linear2.weight',\n", " 'model.vae_encoder.layers.0.norm1.bias',\n", " 'model.vae_encoder.layers.0.norm1.weight',\n", " 'model.vae_encoder.layers.0.norm2.bias',\n", " 
'model.vae_encoder.layers.0.norm2.weight',\n", " 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n", " 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n", " 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n", " 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n", " 'model.vae_encoder.layers.1.linear1.bias',\n", " 'model.vae_encoder.layers.1.linear1.weight',\n", " 'model.vae_encoder.layers.1.linear2.bias',\n", " 'model.vae_encoder.layers.1.linear2.weight',\n", " 'model.vae_encoder.layers.1.norm1.bias',\n", " 'model.vae_encoder.layers.1.norm1.weight',\n", " 'model.vae_encoder.layers.1.norm2.bias',\n", " 'model.vae_encoder.layers.1.norm2.weight',\n", " 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n", " 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.vae_encoder.layers.2.linear1.bias',\n", " 'model.vae_encoder.layers.2.linear1.weight',\n", " 'model.vae_encoder.layers.2.linear2.bias',\n", " 'model.vae_encoder.layers.2.linear2.weight',\n", " 'model.vae_encoder.layers.2.norm1.bias',\n", " 'model.vae_encoder.layers.2.norm1.weight',\n", " 'model.vae_encoder.layers.2.norm2.bias',\n", " 'model.vae_encoder.layers.2.norm2.weight',\n", " 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n", " 'model.vae_encoder.layers.3.linear1.bias',\n", " 'model.vae_encoder.layers.3.linear1.weight',\n", " 'model.vae_encoder.layers.3.linear2.bias',\n", " 'model.vae_encoder.layers.3.linear2.weight',\n", " 'model.vae_encoder.layers.3.norm1.bias',\n", " 'model.vae_encoder.layers.3.norm1.weight',\n", " 'model.vae_encoder.layers.3.norm2.bias',\n", " 'model.vae_encoder.layers.3.norm2.weight',\n", " 
'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n", " 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n", " 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n", " 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n", " 'model.vae_encoder_action_input_proj.bias',\n", " 'model.vae_encoder_action_input_proj.weight',\n", " 'model.vae_encoder_cls_embed.weight',\n", " 'model.vae_encoder_latent_output_proj.bias',\n", " 'model.vae_encoder_latent_output_proj.weight',\n", " 'model.vae_encoder_pos_enc',\n", " 'model.vae_encoder_robot_state_input_proj.bias',\n", " 'model.vae_encoder_robot_state_input_proj.weight',\n", " 'normalize_inputs.buffer_observation_images_front.mean',\n", " 'normalize_inputs.buffer_observation_images_front.std',\n", " 'normalize_inputs.buffer_observation_images_top.mean',\n", " 'normalize_inputs.buffer_observation_images_top.std',\n", " 'normalize_inputs.buffer_observation_state.mean',\n", " 'normalize_inputs.buffer_observation_state.std',\n", " 'normalize_targets.buffer_action.mean',\n", " 'normalize_targets.buffer_action.std',\n", " 'unnormalize_outputs.buffer_action.mean',\n", " 'unnormalize_outputs.buffer_action.std']\n" ] } ], "source": [ "# Key names of the reference LeRobot state dict, in iteration order\n", "# (presumably the naming the converted ACT state dict has to match).\n", "dest = [key for key in b]\n", "pprint(dest)" ] }, { "cell_type": "code", "execution_count": 128, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['model.pos_table',\n", " 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n", " 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n", " 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n", " 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n", " 'model.transformer.encoder.layers.0.linear1.weight',\n", " 'model.transformer.encoder.layers.0.linear1.bias',\n", " 'model.transformer.encoder.layers.0.linear2.weight',\n", " 'model.transformer.encoder.layers.0.linear2.bias',\n", " 'model.transformer.encoder.layers.0.norm1.weight',\n", "
'model.transformer.encoder.layers.0.norm1.bias',\n", " 'model.transformer.encoder.layers.0.norm2.weight',\n", " 'model.transformer.encoder.layers.0.norm2.bias',\n", " 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n", " 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.transformer.encoder.layers.1.linear1.weight',\n", " 'model.transformer.encoder.layers.1.linear1.bias',\n", " 'model.transformer.encoder.layers.1.linear2.weight',\n", " 'model.transformer.encoder.layers.1.linear2.bias',\n", " 'model.transformer.encoder.layers.1.norm1.weight',\n", " 'model.transformer.encoder.layers.1.norm1.bias',\n", " 'model.transformer.encoder.layers.1.norm2.weight',\n", " 'model.transformer.encoder.layers.1.norm2.bias',\n", " 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n", " 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.transformer.encoder.layers.2.linear1.weight',\n", " 'model.transformer.encoder.layers.2.linear1.bias',\n", " 'model.transformer.encoder.layers.2.linear2.weight',\n", " 'model.transformer.encoder.layers.2.linear2.bias',\n", " 'model.transformer.encoder.layers.2.norm1.weight',\n", " 'model.transformer.encoder.layers.2.norm1.bias',\n", " 'model.transformer.encoder.layers.2.norm2.weight',\n", " 'model.transformer.encoder.layers.2.norm2.bias',\n", " 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n", " 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n", " 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n", " 'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n", " 'model.transformer.encoder.layers.3.linear1.weight',\n", " 
'model.transformer.encoder.layers.3.linear1.bias',\n", " 'model.transformer.encoder.layers.3.linear2.weight',\n", " 'model.transformer.encoder.layers.3.linear2.bias',\n", " 'model.transformer.encoder.layers.3.norm1.weight',\n", " 'model.transformer.encoder.layers.3.norm1.bias',\n", " 'model.transformer.encoder.layers.3.norm2.weight',\n", " 'model.transformer.encoder.layers.3.norm2.bias',\n", " 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.0.linear1.weight',\n", " 'model.transformer.decoder.layers.0.linear1.bias',\n", " 'model.transformer.decoder.layers.0.linear2.weight',\n", " 'model.transformer.decoder.layers.0.linear2.bias',\n", " 'model.transformer.decoder.layers.0.norm1.weight',\n", " 'model.transformer.decoder.layers.0.norm1.bias',\n", " 'model.transformer.decoder.layers.0.norm2.weight',\n", " 'model.transformer.decoder.layers.0.norm2.bias',\n", " 'model.transformer.decoder.layers.0.norm3.weight',\n", " 'model.transformer.decoder.layers.0.norm3.bias',\n", " 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n", " 
'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.1.linear1.weight',\n", " 'model.transformer.decoder.layers.1.linear1.bias',\n", " 'model.transformer.decoder.layers.1.linear2.weight',\n", " 'model.transformer.decoder.layers.1.linear2.bias',\n", " 'model.transformer.decoder.layers.1.norm1.weight',\n", " 'model.transformer.decoder.layers.1.norm1.bias',\n", " 'model.transformer.decoder.layers.1.norm2.weight',\n", " 'model.transformer.decoder.layers.1.norm2.bias',\n", " 'model.transformer.decoder.layers.1.norm3.weight',\n", " 'model.transformer.decoder.layers.1.norm3.bias',\n", " 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.2.linear1.weight',\n", " 'model.transformer.decoder.layers.2.linear1.bias',\n", " 'model.transformer.decoder.layers.2.linear2.weight',\n", " 'model.transformer.decoder.layers.2.linear2.bias',\n", " 'model.transformer.decoder.layers.2.norm1.weight',\n", " 'model.transformer.decoder.layers.2.norm1.bias',\n", " 'model.transformer.decoder.layers.2.norm2.weight',\n", " 'model.transformer.decoder.layers.2.norm2.bias',\n", " 'model.transformer.decoder.layers.2.norm3.weight',\n", " 'model.transformer.decoder.layers.2.norm3.bias',\n", " 'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n", " 
'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.3.linear1.weight',\n", " 'model.transformer.decoder.layers.3.linear1.bias',\n", " 'model.transformer.decoder.layers.3.linear2.weight',\n", " 'model.transformer.decoder.layers.3.linear2.bias',\n", " 'model.transformer.decoder.layers.3.norm1.weight',\n", " 'model.transformer.decoder.layers.3.norm1.bias',\n", " 'model.transformer.decoder.layers.3.norm2.weight',\n", " 'model.transformer.decoder.layers.3.norm2.bias',\n", " 'model.transformer.decoder.layers.3.norm3.weight',\n", " 'model.transformer.decoder.layers.3.norm3.bias',\n", " 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.4.linear1.weight',\n", " 'model.transformer.decoder.layers.4.linear1.bias',\n", " 'model.transformer.decoder.layers.4.linear2.weight',\n", " 'model.transformer.decoder.layers.4.linear2.bias',\n", " 'model.transformer.decoder.layers.4.norm1.weight',\n", " 'model.transformer.decoder.layers.4.norm1.bias',\n", " 'model.transformer.decoder.layers.4.norm2.weight',\n", " 
'model.transformer.decoder.layers.4.norm2.bias',\n", " 'model.transformer.decoder.layers.4.norm3.weight',\n", " 'model.transformer.decoder.layers.4.norm3.bias',\n", " 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.5.linear1.weight',\n", " 'model.transformer.decoder.layers.5.linear1.bias',\n", " 'model.transformer.decoder.layers.5.linear2.weight',\n", " 'model.transformer.decoder.layers.5.linear2.bias',\n", " 'model.transformer.decoder.layers.5.norm1.weight',\n", " 'model.transformer.decoder.layers.5.norm1.bias',\n", " 'model.transformer.decoder.layers.5.norm2.weight',\n", " 'model.transformer.decoder.layers.5.norm2.bias',\n", " 'model.transformer.decoder.layers.5.norm3.weight',\n", " 'model.transformer.decoder.layers.5.norm3.bias',\n", " 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.6.linear1.weight',\n", " 'model.transformer.decoder.layers.6.linear1.bias',\n", " 
'model.transformer.decoder.layers.6.linear2.weight',\n", " 'model.transformer.decoder.layers.6.linear2.bias',\n", " 'model.transformer.decoder.layers.6.norm1.weight',\n", " 'model.transformer.decoder.layers.6.norm1.bias',\n", " 'model.transformer.decoder.layers.6.norm2.weight',\n", " 'model.transformer.decoder.layers.6.norm2.bias',\n", " 'model.transformer.decoder.layers.6.norm3.weight',\n", " 'model.transformer.decoder.layers.6.norm3.bias',\n", " 'model.transformer.decoder.norm.weight',\n", " 'model.transformer.decoder.norm.bias',\n", " 'model.encoder.layers.0.self_attn.in_proj_weight',\n", " 'model.encoder.layers.0.self_attn.in_proj_bias',\n", " 'model.encoder.layers.0.self_attn.out_proj.weight',\n", " 'model.encoder.layers.0.self_attn.out_proj.bias',\n", " 'model.encoder.layers.0.linear1.weight',\n", " 'model.encoder.layers.0.linear1.bias',\n", " 'model.encoder.layers.0.linear2.weight',\n", " 'model.encoder.layers.0.linear2.bias',\n", " 'model.encoder.layers.0.norm1.weight',\n", " 'model.encoder.layers.0.norm1.bias',\n", " 'model.encoder.layers.0.norm2.weight',\n", " 'model.encoder.layers.0.norm2.bias',\n", " 'model.encoder.layers.1.self_attn.in_proj_weight',\n", " 'model.encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.encoder.layers.1.linear1.weight',\n", " 'model.encoder.layers.1.linear1.bias',\n", " 'model.encoder.layers.1.linear2.weight',\n", " 'model.encoder.layers.1.linear2.bias',\n", " 'model.encoder.layers.1.norm1.weight',\n", " 'model.encoder.layers.1.norm1.bias',\n", " 'model.encoder.layers.1.norm2.weight',\n", " 'model.encoder.layers.1.norm2.bias',\n", " 'model.encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.encoder.layers.2.self_attn.out_proj.weight',\n", " 'model.encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.encoder.layers.2.linear1.weight',\n", " 
'model.encoder.layers.2.linear1.bias',\n", " 'model.encoder.layers.2.linear2.weight',\n", " 'model.encoder.layers.2.linear2.bias',\n", " 'model.encoder.layers.2.norm1.weight',\n", " 'model.encoder.layers.2.norm1.bias',\n", " 'model.encoder.layers.2.norm2.weight',\n", " 'model.encoder.layers.2.norm2.bias',\n", " 'model.encoder.layers.3.self_attn.in_proj_weight',\n", " 'model.encoder.layers.3.self_attn.in_proj_bias',\n", " 'model.encoder.layers.3.self_attn.out_proj.weight',\n", " 'model.encoder.layers.3.self_attn.out_proj.bias',\n", " 'model.encoder.layers.3.linear1.weight',\n", " 'model.encoder.layers.3.linear1.bias',\n", " 'model.encoder.layers.3.linear2.weight',\n", " 'model.encoder.layers.3.linear2.bias',\n", " 'model.encoder.layers.3.norm1.weight',\n", " 'model.encoder.layers.3.norm1.bias',\n", " 'model.encoder.layers.3.norm2.weight',\n", " 'model.encoder.layers.3.norm2.bias',\n", " 'model.action_head.weight',\n", " 'model.action_head.bias',\n", " 'model.is_pad_head.weight',\n", " 'model.is_pad_head.bias',\n", " 'model.query_embed.weight',\n", " 'model.input_proj.weight',\n", " 'model.input_proj.bias',\n", " 'model.backbones.0.0.body.conv1.weight',\n", " 'model.backbones.0.0.body.bn1.weight',\n", " 'model.backbones.0.0.body.bn1.bias',\n", " 'model.backbones.0.0.body.bn1.running_mean',\n", " 'model.backbones.0.0.body.bn1.running_var',\n", " 'model.backbones.0.0.body.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer1.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer1.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer1.0.bn2.weight',\n", " 'model.backbones.0.0.body.layer1.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n", " 
'model.backbones.0.0.body.layer1.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.1.conv1.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn1.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n", " 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.1.conv2.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn2.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.1.conv1.weight',\n", " 
'model.backbones.0.0.body.layer2.1.bn1.weight',\n", " 'model.backbones.0.0.body.layer2.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n", " 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.1.conv2.weight',\n", " 'model.backbones.0.0.body.layer2.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer2.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn2.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.1.conv1.weight',\n", " 'model.backbones.0.0.body.layer3.1.bn1.weight',\n", " 'model.backbones.0.0.body.layer3.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n", " 
'model.backbones.0.0.body.layer3.1.bn1.running_var',\n", " 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.1.conv2.weight',\n", " 'model.backbones.0.0.body.layer3.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer3.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn2.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.1.conv1.weight',\n", " 'model.backbones.0.0.body.layer4.1.bn1.weight',\n", " 'model.backbones.0.0.body.layer4.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n", " 'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.1.conv2.weight',\n", " 
'model.backbones.0.0.body.layer4.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer4.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n", " 'model.input_proj_robot_state.weight',\n", " 'model.input_proj_robot_state.bias',\n", " 'model.cls_embed.weight',\n", " 'model.encoder_action_proj.weight',\n", " 'model.encoder_action_proj.bias',\n", " 'model.encoder_joint_proj.weight',\n", " 'model.encoder_joint_proj.bias',\n", " 'model.latent_proj.weight',\n", " 'model.latent_proj.bias',\n", " 'model.latent_out_proj.weight',\n", " 'model.latent_out_proj.bias',\n", " 'model.additional_pos_embed.weight']\n" ] } ], "source": [ "orig = list(a.keys())\n", "pprint(orig)" ] }, { "cell_type": "code", "execution_count": 129, "metadata": {}, "outputs": [], "source": [
"a = torch.load(original_ckpt_path)\n",
"\n",
"# Weights with no counterpart in the LeRobot ACT policy: decoder layers 1-6 (only decoder\n",
"# layer 0 is converted) and the `is_pad_head`.\n",
"to_remove_startswith = ['model.transformer.decoder.layers.1.',\n",
"                        'model.transformer.decoder.layers.2.',\n",
"                        'model.transformer.decoder.layers.3.',\n",
"                        'model.transformer.decoder.layers.4.',\n",
"                        'model.transformer.decoder.layers.5.',\n",
"                        'model.transformer.decoder.layers.6.',\n",
"                        'model.is_pad_head']\n",
"\n",
"# BatchNorm step counters are not part of the LeRobot checkpoint either.\n",
"to_remove_in = ['num_batches_tracked',]\n",
"\n",
"conv = {}\n",
"\n",
"keys = list(a.keys())\n",
"for k in keys:\n",
"    if any(k.startswith(tr) for tr in to_remove_startswith):\n",
"        a.pop(k)\n",
"        continue\n",
"    if any(tr in k for tr in to_remove_in):\n",
"        a.pop(k)\n",
"        continue\n",
"    # Rename ACT keys to their LeRobot names; each key matches at most one prefix.\n",
"    if k.startswith('model.transformer.encoder.layers.'):\n",
"        conv[k.replace('transformer.', '')] = a.pop(k)\n",
"    elif k.startswith('model.transformer.decoder.layers.0.'):\n",
"        conv[k.replace('transformer.', '')] = a.pop(k)\n",
"    elif k.startswith('model.transformer.decoder.norm.'):\n",
"        conv[k.replace('transformer.', '')] = a.pop(k)\n",
"    elif k.startswith('model.encoder.layers.'):\n",
"        conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n",
"    elif k.startswith('model.action_head.'):\n",
"        conv[k] = a.pop(k)\n",
"    elif k.startswith('model.pos_table'):\n",
"        conv[k.replace('pos_table', 'vae_encoder_pos_enc')] = a.pop(k)\n",
"    elif k.startswith('model.query_embed.'):\n",
"        conv[k.replace('query_embed', 'decoder_pos_embed')] = a.pop(k)\n",
"    elif k.startswith('model.input_proj.'):\n",
"        conv[k.replace('input_proj.', 'encoder_img_feat_input_proj.')] = a.pop(k)\n",
"    elif k.startswith('model.input_proj_robot_state.'):\n",
"        conv[k.replace('input_proj_robot_state.', 'encoder_robot_state_input_proj.')] = a.pop(k)\n",
"    elif k.startswith('model.backbones.0.0.body.'):\n",
"        conv[k.replace('backbones.0.0.body', 'backbone')] = a.pop(k)\n",
"    elif k.startswith('model.cls_embed.'):\n",
"        conv[k.replace('cls_embed', 'vae_encoder_cls_embed')] = a.pop(k)\n",
"    elif k.startswith('model.encoder_action_proj.'):\n",
"        conv[k.replace('encoder_action_proj', 'vae_encoder_action_input_proj')] = a.pop(k)\n",
"    elif k.startswith('model.encoder_joint_proj.'):\n",
"        conv[k.replace('encoder_joint_proj', 'vae_encoder_robot_state_input_proj')] = a.pop(k)\n",
"    elif k.startswith('model.latent_proj.'):\n",
"        conv[k.replace('latent_proj', 'vae_encoder_latent_output_proj')] = a.pop(k)\n",
"    elif k.startswith('model.latent_out_proj.'):\n",
"        conv[k.replace('latent_out_proj', 'encoder_latent_input_proj')] = a.pop(k)\n",
"    elif k.startswith('model.additional_pos_embed.'):\n",
"        conv[k.replace('additional_pos_embed', 'encoder_robot_and_latent_pos_embed')] = a.pop(k)\n",
"\n",
"# Every original key must have been either dropped or renamed above; fail loudly otherwise.\n",
"assert not a, f'Unmapped keys left in the original checkpoint: {list(a)}'"
] },
{ "cell_type": "code", "execution_count": 130, "metadata": {}, "outputs": [], "source": [ "conv['normalize_inputs.buffer_observation_images_front.mean'] = image_stats_mean[:, None, None]\n", "conv['normalize_inputs.buffer_observation_images_front.std'] = image_stats_std[:, None, None]\n", "conv['normalize_inputs.buffer_observation_images_top.mean'] = image_stats_mean[:, None, None].clone()\n", "conv['normalize_inputs.buffer_observation_images_top.std'] =
image_stats_std[:, None, None].clone()\n", "\n", "conv['normalize_inputs.buffer_observation_state.mean'] = torch.tensor(stats['qpos_mean'])\n", "conv['normalize_inputs.buffer_observation_state.std'] = torch.tensor(stats['qpos_std'])\n", "conv['normalize_targets.buffer_action.mean'] = torch.tensor(stats['action_mean'])\n", "conv['normalize_targets.buffer_action.std'] = torch.tensor(stats['action_std'])\n", "\n", "conv['unnormalize_outputs.buffer_action.mean'] = torch.tensor(stats['action_mean'])\n", "conv['unnormalize_outputs.buffer_action.std'] = torch.tensor(stats['action_std'])" ] }, { "cell_type": "code", "execution_count": 131, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OrderedDict()" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a" ] }, { "cell_type": "code", "execution_count": 132, "metadata": {}, "outputs": [], "source": [
"# Copy every converted tensor into the reference LeRobot state dict `b`. Each converted key must\n",
"# already exist in `b` with exactly the same shape; fail here instead of when the saved checkpoint\n",
"# is later loaded into the policy.\n",
"not_converted = set(b.keys())\n",
"for k, v in conv.items():\n",
"    if k not in b:\n",
"        raise KeyError(f'Converted key {k!r} is not in the reference LeRobot checkpoint')\n",
"    if b[k].shape != v.shape:\n",
"        raise ValueError(f'Shape mismatch for {k!r}: reference {tuple(b[k].shape)}, converted {tuple(v.shape)}')\n",
"    b[k] = v\n",
"    not_converted.remove(k)"
] },
{ "cell_type": "code", "execution_count": 133, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "set()" ] }, "execution_count": 133, "metadata": {}, "output_type": "execute_result" } ], "source": [ "not_converted" ] }, { "cell_type": "code", "execution_count": 134, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "defaultdict(, {})\n" ] } ], "source": [ "metadata = None\n", "to_removes = _remove_duplicate_names(b)\n", "print(to_removes)" ] }, { "cell_type": "code", "execution_count": 135, "metadata": {}, "outputs": [], "source": [
"# Same de-duplication as `safetensors.torch.save_model`: drop aliased names found by\n",
"# `_remove_duplicate_names` and record, in the file metadata, which kept name each one pointed to.\n",
"for kept_name, to_remove_group in to_removes.items():\n",
"    for to_remove in to_remove_group:\n",
"        if metadata is None:\n",
"            metadata = {}\n",
"\n",
"        if to_remove not in metadata:\n",
"            # Do not override user data\n",
"            metadata[to_remove] = kept_name\n",
"        del b[to_remove]\n",
"save_file(b, converted_ckpt_path, metadata=metadata)"
] },
{ "cell_type": "code", "execution_count": 136, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort_raw/initial_state/config.yaml'" ] }, "execution_count": 136, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Now also copy the config files\n", "import shutil\n", "shutil.copy(comparison_config_json_path, converted_ckpt_path.replace('model.safetensors', 'config.json'))\n", "shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))" ] }, { "cell_type": "code", "execution_count": 137, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['model.action_head.bias', 'model.action_head.weight', 'model.backbone.bn1.bias', 'model.backbone.bn1.running_mean', 'model.backbone.bn1.running_var', 'model.backbone.bn1.weight', 'model.backbone.conv1.weight', 'model.backbone.layer1.0.bn1.bias', 'model.backbone.layer1.0.bn1.running_mean', 'model.backbone.layer1.0.bn1.running_var', 'model.backbone.layer1.0.bn1.weight', 'model.backbone.layer1.0.bn2.bias', 'model.backbone.layer1.0.bn2.running_mean', 'model.backbone.layer1.0.bn2.running_var', 'model.backbone.layer1.0.bn2.weight', 'model.backbone.layer1.0.conv1.weight', 'model.backbone.layer1.0.conv2.weight', 'model.backbone.layer1.1.bn1.bias', 'model.backbone.layer1.1.bn1.running_mean', 'model.backbone.layer1.1.bn1.running_var', 'model.backbone.layer1.1.bn1.weight', 'model.backbone.layer1.1.bn2.bias', 'model.backbone.layer1.1.bn2.running_mean', 'model.backbone.layer1.1.bn2.running_var', 'model.backbone.layer1.1.bn2.weight', 'model.backbone.layer1.1.conv1.weight', 'model.backbone.layer1.1.conv2.weight', 'model.backbone.layer2.0.bn1.bias', 'model.backbone.layer2.0.bn1.running_mean', 'model.backbone.layer2.0.bn1.running_var', 'model.backbone.layer2.0.bn1.weight', 'model.backbone.layer2.0.bn2.bias',
'model.backbone.layer2.0.bn2.running_mean', 'model.backbone.layer2.0.bn2.running_var', 'model.backbone.layer2.0.bn2.weight', 'model.backbone.layer2.0.conv1.weight', 'model.backbone.layer2.0.conv2.weight', 'model.backbone.layer2.0.downsample.0.weight', 'model.backbone.layer2.0.downsample.1.bias', 'model.backbone.layer2.0.downsample.1.running_mean', 'model.backbone.layer2.0.downsample.1.running_var', 'model.backbone.layer2.0.downsample.1.weight', 'model.backbone.layer2.1.bn1.bias', 'model.backbone.layer2.1.bn1.running_mean', 'model.backbone.layer2.1.bn1.running_var', 'model.backbone.layer2.1.bn1.weight', 'model.backbone.layer2.1.bn2.bias', 'model.backbone.layer2.1.bn2.running_mean', 'model.backbone.layer2.1.bn2.running_var', 'model.backbone.layer2.1.bn2.weight', 'model.backbone.layer2.1.conv1.weight', 'model.backbone.layer2.1.conv2.weight', 'model.backbone.layer3.0.bn1.bias', 'model.backbone.layer3.0.bn1.running_mean', 'model.backbone.layer3.0.bn1.running_var', 'model.backbone.layer3.0.bn1.weight', 'model.backbone.layer3.0.bn2.bias', 'model.backbone.layer3.0.bn2.running_mean', 'model.backbone.layer3.0.bn2.running_var', 'model.backbone.layer3.0.bn2.weight', 'model.backbone.layer3.0.conv1.weight', 'model.backbone.layer3.0.conv2.weight', 'model.backbone.layer3.0.downsample.0.weight', 'model.backbone.layer3.0.downsample.1.bias', 'model.backbone.layer3.0.downsample.1.running_mean', 'model.backbone.layer3.0.downsample.1.running_var', 'model.backbone.layer3.0.downsample.1.weight', 'model.backbone.layer3.1.bn1.bias', 'model.backbone.layer3.1.bn1.running_mean', 'model.backbone.layer3.1.bn1.running_var', 'model.backbone.layer3.1.bn1.weight', 'model.backbone.layer3.1.bn2.bias', 'model.backbone.layer3.1.bn2.running_mean', 'model.backbone.layer3.1.bn2.running_var', 'model.backbone.layer3.1.bn2.weight', 'model.backbone.layer3.1.conv1.weight', 'model.backbone.layer3.1.conv2.weight', 'model.backbone.layer4.0.bn1.bias', 'model.backbone.layer4.0.bn1.running_mean', 
'model.backbone.layer4.0.bn1.running_var', 'model.backbone.layer4.0.bn1.weight', 'model.backbone.layer4.0.bn2.bias', 'model.backbone.layer4.0.bn2.running_mean', 'model.backbone.layer4.0.bn2.running_var', 'model.backbone.layer4.0.bn2.weight', 'model.backbone.layer4.0.conv1.weight', 'model.backbone.layer4.0.conv2.weight', 'model.backbone.layer4.0.downsample.0.weight', 'model.backbone.layer4.0.downsample.1.bias', 'model.backbone.layer4.0.downsample.1.running_mean', 'model.backbone.layer4.0.downsample.1.running_var', 'model.backbone.layer4.0.downsample.1.weight', 'model.backbone.layer4.1.bn1.bias', 'model.backbone.layer4.1.bn1.running_mean', 'model.backbone.layer4.1.bn1.running_var', 'model.backbone.layer4.1.bn1.weight', 'model.backbone.layer4.1.bn2.bias', 'model.backbone.layer4.1.bn2.running_mean', 'model.backbone.layer4.1.bn2.running_var', 'model.backbone.layer4.1.bn2.weight', 'model.backbone.layer4.1.conv1.weight', 'model.backbone.layer4.1.conv2.weight', 'model.decoder.layers.0.linear1.bias', 'model.decoder.layers.0.linear1.weight', 'model.decoder.layers.0.linear2.bias', 'model.decoder.layers.0.linear2.weight', 'model.decoder.layers.0.multihead_attn.in_proj_bias', 'model.decoder.layers.0.multihead_attn.in_proj_weight', 'model.decoder.layers.0.multihead_attn.out_proj.bias', 'model.decoder.layers.0.multihead_attn.out_proj.weight', 'model.decoder.layers.0.norm1.bias', 'model.decoder.layers.0.norm1.weight', 'model.decoder.layers.0.norm2.bias', 'model.decoder.layers.0.norm2.weight', 'model.decoder.layers.0.norm3.bias', 'model.decoder.layers.0.norm3.weight', 'model.decoder.layers.0.self_attn.in_proj_bias', 'model.decoder.layers.0.self_attn.in_proj_weight', 'model.decoder.layers.0.self_attn.out_proj.bias', 'model.decoder.layers.0.self_attn.out_proj.weight', 'model.decoder_pos_embed.weight', 'model.encoder.layers.0.linear1.bias', 'model.encoder.layers.0.linear1.weight', 'model.encoder.layers.0.linear2.bias', 'model.encoder.layers.0.linear2.weight', 
'model.encoder.layers.0.norm1.bias', 'model.encoder.layers.0.norm1.weight', 'model.encoder.layers.0.norm2.bias', 'model.encoder.layers.0.norm2.weight', 'model.encoder.layers.0.self_attn.in_proj_bias', 'model.encoder.layers.0.self_attn.in_proj_weight', 'model.encoder.layers.0.self_attn.out_proj.bias', 'model.encoder.layers.0.self_attn.out_proj.weight', 'model.encoder.layers.1.linear1.bias', 'model.encoder.layers.1.linear1.weight', 'model.encoder.layers.1.linear2.bias', 'model.encoder.layers.1.linear2.weight', 'model.encoder.layers.1.norm1.bias', 'model.encoder.layers.1.norm1.weight', 'model.encoder.layers.1.norm2.bias', 'model.encoder.layers.1.norm2.weight', 'model.encoder.layers.1.self_attn.in_proj_bias', 'model.encoder.layers.1.self_attn.in_proj_weight', 'model.encoder.layers.1.self_attn.out_proj.bias', 'model.encoder.layers.1.self_attn.out_proj.weight', 'model.encoder.layers.2.linear1.bias', 'model.encoder.layers.2.linear1.weight', 'model.encoder.layers.2.linear2.bias', 'model.encoder.layers.2.linear2.weight', 'model.encoder.layers.2.norm1.bias', 'model.encoder.layers.2.norm1.weight', 'model.encoder.layers.2.norm2.bias', 'model.encoder.layers.2.norm2.weight', 'model.encoder.layers.2.self_attn.in_proj_bias', 'model.encoder.layers.2.self_attn.in_proj_weight', 'model.encoder.layers.2.self_attn.out_proj.bias', 'model.encoder.layers.2.self_attn.out_proj.weight', 'model.encoder.layers.3.linear1.bias', 'model.encoder.layers.3.linear1.weight', 'model.encoder.layers.3.linear2.bias', 'model.encoder.layers.3.linear2.weight', 'model.encoder.layers.3.norm1.bias', 'model.encoder.layers.3.norm1.weight', 'model.encoder.layers.3.norm2.bias', 'model.encoder.layers.3.norm2.weight', 'model.encoder.layers.3.self_attn.in_proj_bias', 'model.encoder.layers.3.self_attn.in_proj_weight', 'model.encoder.layers.3.self_attn.out_proj.bias', 'model.encoder.layers.3.self_attn.out_proj.weight', 'model.encoder_img_feat_input_proj.bias', 'model.encoder_img_feat_input_proj.weight', 
'model.encoder_latent_input_proj.bias', 'model.encoder_latent_input_proj.weight', 'model.encoder_robot_and_latent_pos_embed.weight', 'model.encoder_robot_state_input_proj.bias', 'model.encoder_robot_state_input_proj.weight', 'model.vae_encoder.layers.0.linear1.bias', 'model.vae_encoder.layers.0.linear1.weight', 'model.vae_encoder.layers.0.linear2.bias', 'model.vae_encoder.layers.0.linear2.weight', 'model.vae_encoder.layers.0.norm1.bias', 'model.vae_encoder.layers.0.norm1.weight', 'model.vae_encoder.layers.0.norm2.bias', 'model.vae_encoder.layers.0.norm2.weight', 'model.vae_encoder.layers.0.self_attn.in_proj_bias', 'model.vae_encoder.layers.0.self_attn.in_proj_weight', 'model.vae_encoder.layers.0.self_attn.out_proj.bias', 'model.vae_encoder.layers.0.self_attn.out_proj.weight', 'model.vae_encoder.layers.1.linear1.bias', 'model.vae_encoder.layers.1.linear1.weight', 'model.vae_encoder.layers.1.linear2.bias', 'model.vae_encoder.layers.1.linear2.weight', 'model.vae_encoder.layers.1.norm1.bias', 'model.vae_encoder.layers.1.norm1.weight', 'model.vae_encoder.layers.1.norm2.bias', 'model.vae_encoder.layers.1.norm2.weight', 'model.vae_encoder.layers.1.self_attn.in_proj_bias', 'model.vae_encoder.layers.1.self_attn.in_proj_weight', 'model.vae_encoder.layers.1.self_attn.out_proj.bias', 'model.vae_encoder.layers.1.self_attn.out_proj.weight', 'model.vae_encoder.layers.2.linear1.bias', 'model.vae_encoder.layers.2.linear1.weight', 'model.vae_encoder.layers.2.linear2.bias', 'model.vae_encoder.layers.2.linear2.weight', 'model.vae_encoder.layers.2.norm1.bias', 'model.vae_encoder.layers.2.norm1.weight', 'model.vae_encoder.layers.2.norm2.bias', 'model.vae_encoder.layers.2.norm2.weight', 'model.vae_encoder.layers.2.self_attn.in_proj_bias', 'model.vae_encoder.layers.2.self_attn.in_proj_weight', 'model.vae_encoder.layers.2.self_attn.out_proj.bias', 'model.vae_encoder.layers.2.self_attn.out_proj.weight', 'model.vae_encoder.layers.3.linear1.bias', 'model.vae_encoder.layers.3.linear1.weight', 
'model.vae_encoder.layers.3.linear2.bias', 'model.vae_encoder.layers.3.linear2.weight', 'model.vae_encoder.layers.3.norm1.bias', 'model.vae_encoder.layers.3.norm1.weight', 'model.vae_encoder.layers.3.norm2.bias', 'model.vae_encoder.layers.3.norm2.weight', 'model.vae_encoder.layers.3.self_attn.in_proj_bias', 'model.vae_encoder.layers.3.self_attn.in_proj_weight', 'model.vae_encoder.layers.3.self_attn.out_proj.bias', 'model.vae_encoder.layers.3.self_attn.out_proj.weight', 'model.vae_encoder_action_input_proj.bias', 'model.vae_encoder_action_input_proj.weight', 'model.vae_encoder_cls_embed.weight', 'model.vae_encoder_latent_output_proj.bias', 'model.vae_encoder_latent_output_proj.weight', 'model.vae_encoder_pos_enc', 'model.vae_encoder_robot_state_input_proj.bias', 'model.vae_encoder_robot_state_input_proj.weight', 'normalize_inputs.buffer_observation_images_front.mean', 'normalize_inputs.buffer_observation_images_front.std', 'normalize_inputs.buffer_observation_images_top.mean', 'normalize_inputs.buffer_observation_images_top.std', 'normalize_inputs.buffer_observation_state.mean', 'normalize_inputs.buffer_observation_state.std', 'normalize_targets.buffer_action.mean', 'normalize_targets.buffer_action.std', 'unnormalize_outputs.buffer_action.mean', 'unnormalize_outputs.buffer_action.std'])" ] }, "execution_count": 137, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = load_file(converted_ckpt_path)\n", "c.keys()" ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[0.4850]],\n", "\n", " [[0.4560]],\n", "\n", " [[0.4060]]])" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c['normalize_inputs.buffer_observation_images_front.mean']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Round-trip check: the saved file must hold exactly the keys and tensors written from `b`.\n",
"# `torch.load` restores tensors to the device they were saved from, while `load_file` returns CPU\n",
"# tensors, so compare on CPU.\n",
"assert c.keys() == b.keys()\n",
"assert all(torch.equal(c[k], b[k].cpu()) for k in c)"
] } ], "metadata": { "kernelspec": { "display_name": "lerobot", "language": "python", "name": "python3" },
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }