Files
lerobot/examples/real_robot_example/convert_original_act_batch..ipynb
Thomas Wolf c108bfe840 save state
2024-06-12 20:46:26 +02:00

93 lines
2.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from pprint import pprint\n",
"import pickle\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"original_batch_file = \"/home/thomwolf/Documents/Github/ACT/batch_save.pt\"\n",
"data = torch.load(original_batch_file)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"#orig: image_data, qpos_data, action_data, is_pad\n",
"#target: ['observation.images.front', 'observation.images.top', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'action_is_pad']"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"conv = {}\n",
"conv['observation.images.front'] = data[0][:, 0]\n",
"conv['observation.images.top'] = data[0][:, 1]\n",
"conv['observation.state'] = data[1]\n",
"conv['action'] = data[2]\n",
"conv['episode_index'] = np.zeros(data[0].shape[0])\n",
"conv['frame_index'] = np.zeros(data[0].shape[0])\n",
"conv['timestamp'] = np.zeros(data[0].shape[0])\n",
"conv['next.done'] = np.zeros(data[0].shape[0])\n",
"conv['index'] = np.arange(data[0].shape[0])\n",
"conv['action_is_pad'] = data[3]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"torch.save(conv, \"/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.-1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}