save state

This commit is contained in:
Thomas Wolf
2024-06-12 20:46:26 +02:00
parent a7c030076f
commit c108bfe840
6 changed files with 385 additions and 39 deletions

View File

@@ -0,0 +1,14 @@
inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt')
conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt')
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt')
print((out-d).abs().max())
tensor(0.0044, device='cuda:0', grad_fn=<MaxBackward1>)
inp = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_inp.pt').to('cpu')
conv = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_conv.pt').to('cpu')
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
d = torch.load('/home/thomwolf/Documents/Github/ACT/tensor_out.pt')
print((out-d).abs().max())
tensor(0., grad_fn=<MaxBackward1>)
out = torch.nn.functional.conv2d(inp, conv, bias=None, stride=1, padding=1, dilation=1, groups=1)
torch.save(out, '/home/thomwolf/Documents/Github/ACT/tensor_out_lerobot.pt')

View File

@@ -0,0 +1,92 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from pprint import pprint\n",
"import pickle\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"original_batch_file = \"/home/thomwolf/Documents/Github/ACT/batch_save.pt\"\n",
"data = torch.load(original_batch_file)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"#orig: image_data, qpos_data, action_data, is_pad\n",
"#target: ['observation.images.front', 'observation.images.top', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'action_is_pad']"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"conv = {}\n",
"conv['observation.images.front'] = data[0][:, 0]\n",
"conv['observation.images.top'] = data[0][:, 1]\n",
"conv['observation.state'] = data[1]\n",
"conv['action'] = data[2]\n",
"conv['episode_index'] = np.zeros(data[0].shape[0])\n",
"conv['frame_index'] = np.zeros(data[0].shape[0])\n",
"conv['timestamp'] = np.zeros(data[0].shape[0])\n",
"conv['next.done'] = np.zeros(data[0].shape[0])\n",
"conv['index'] = np.arange(data[0].shape[0])\n",
"conv['action_is_pad'] = data[3]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"torch.save(conv, \"/home/thomwolf/Documents/Github/ACT/batch_save_converted.pt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.-1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long