93 lines
2.2 KiB
Plaintext
93 lines
2.2 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pickle\n",
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import numpy as np\n",
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Directory holding the batch dumped by the original ACT code -- edit this for your machine.\n",
"ACT_DIR = Path(\"/home/thomwolf/Documents/Github/ACT\")\n",
"original_batch_file = ACT_DIR / \"batch_save.pt\"\n",
"\n",
"# map_location='cpu' lets the batch load on a CPU-only machine even if it was saved from GPU tensors.\n",
"# torch.load unpickles the file, so only point it at batches you produced yourself.\n",
"data = torch.load(original_batch_file, map_location=\"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Original ACT batch, indexed positionally as data[0..3] in the next cell:\n",
"#   orig: image_data, qpos_data, action_data, is_pad\n",
"# Target LeRobot batch keys:\n",
"#   target: ['observation.images.front', 'observation.images.top', 'observation.state', 'action',\n",
"#            'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'action_is_pad']"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Repack the positional ACT batch (data[0..3]) into a dict keyed like a LeRobot batch.\n",
"# Assumes the elements of `data` are torch tensors with the batch on dim 0 -- confirm.\n",
"image_data, qpos_data, action_data, is_pad = data[0], data[1], data[2], data[3]\n",
"batch_size = image_data.shape[0]\n",
"device = image_data.device\n",
"\n",
"conv = {\n",
"    # NOTE(review): assumes image_data is (batch, num_cameras, ...) with camera 0 = front and\n",
"    # camera 1 = top; confirm against the camera order used when the ACT batch was saved.\n",
"    'observation.images.front': image_data[:, 0],\n",
"    'observation.images.top': image_data[:, 1],\n",
"    'observation.state': qpos_data,\n",
"    'action': action_data,\n",
"    # The ACT batch carries no episode/time bookkeeping, so these are placeholders.\n",
"    # They are torch tensors (not numpy arrays) so every value in the dict has the same type and\n",
"    # device; the dtypes (int64 indices, float32 timestamp, bool done flag) are presumably what\n",
"    # LeRobot batches use -- confirm against the dataset being compared.\n",
"    'episode_index': torch.zeros(batch_size, dtype=torch.int64, device=device),\n",
"    'frame_index': torch.zeros(batch_size, dtype=torch.int64, device=device),\n",
"    'timestamp': torch.zeros(batch_size, dtype=torch.float32, device=device),\n",
"    'next.done': torch.zeros(batch_size, dtype=torch.bool, device=device),\n",
"    'index': torch.arange(batch_size, dtype=torch.int64, device=device),\n",
"    'action_is_pad': is_pad,\n",
"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Write the converted batch next to the original one (batch_save.pt -> batch_save_converted.pt),\n",
"# so changing the input location in the load cell also moves the output.\n",
"source_path = Path(original_batch_file)\n",
"converted_batch_file = source_path.with_name(f\"{source_path.stem}_converted{source_path.suffix}\")\n",
"torch.save(conv, converted_batch_file)\n",
"converted_batch_file"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "lerobot",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.1.-1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|