{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert an original ACT batch to the LeRobot batch format\n",
    "\n",
    "Loads a batch saved by the original ACT code (a sequence `(image_data, qpos_data, action_data, is_pad)`), renames its fields to LeRobot batch keys, fills the remaining LeRobot keys with placeholder values, and saves the converted dict with `torch.save`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set ACT_DIR in the environment to point at another checkout; the default is the original author's path.\n",
    "ACT_DIR = Path(os.environ.get(\"ACT_DIR\", \"/home/thomwolf/Documents/Github/ACT\"))\n",
    "ORIGINAL_BATCH_FILE = ACT_DIR / \"batch_save.pt\"\n",
    "CONVERTED_BATCH_FILE = ACT_DIR / \"batch_save_converted.pt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the original batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.load unpickles the file: only load batches you produced yourself.\n",
    "# map_location=\"cpu\" lets a batch saved from GPU tensors load on any machine and keeps\n",
    "# every tensor on the same device as the placeholder tensors created below.\n",
    "original_batch = torch.load(ORIGINAL_BATCH_FILE, map_location=\"cpu\")\n",
    "image_data, qpos_data, action_data, is_pad = original_batch\n",
    "batch_size = image_data.shape[0]\n",
    "\n",
    "# Every field must share the batch dimension; dim 1 of image_data must hold at least two camera views.\n",
    "assert qpos_data.shape[0] == action_data.shape[0] == is_pad.shape[0] == batch_size, \"batch dimension mismatch\"\n",
    "assert image_data.shape[1] >= 2, \"expected at least two camera views on dim 1 of image_data\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Map to LeRobot keys\n",
    "\n",
    "| Original field | LeRobot key |\n",
    "|---|---|\n",
    "| `image_data[:, 0]` | `observation.images.front` |\n",
    "| `image_data[:, 1]` | `observation.images.top` |\n",
    "| `qpos_data` | `observation.state` |\n",
    "| `action_data` | `action` |\n",
    "| `is_pad` | `action_is_pad` |\n",
    "\n",
    "The remaining keys (`episode_index`, `frame_index`, `timestamp`, `next.done`, `index`) have no counterpart in the original batch. They are filled with placeholders: zeros, except `index`, which numbers the samples `0 .. batch_size - 1`.\n",
    "\n",
    "**TODO:** confirm the camera order on dim 1 of `image_data` (front = 0, top = 1) against the ACT data loader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "converted_batch = {\n",
    "    \"observation.images.front\": image_data[:, 0],\n",
    "    \"observation.images.top\": image_data[:, 1],\n",
    "    \"observation.state\": qpos_data,\n",
    "    \"action\": action_data,\n",
    "    # Placeholder bookkeeping fields are torch tensors (not numpy arrays) so the whole batch\n",
    "    # can be handled uniformly, e.g. moved with `.to(device)`.\n",
    "    # NOTE(review): dtypes assumed to match LeRobot datasets; TODO confirm.\n",
    "    \"episode_index\": torch.zeros(batch_size, dtype=torch.int64),\n",
    "    \"frame_index\": torch.zeros(batch_size, dtype=torch.int64),\n",
    "    \"timestamp\": torch.zeros(batch_size, dtype=torch.float32),\n",
    "    \"next.done\": torch.zeros(batch_size, dtype=torch.bool),\n",
    "    \"index\": torch.arange(batch_size, dtype=torch.int64),\n",
    "    \"action_is_pad\": is_pad,\n",
    "}\n",
    "{key: (tuple(value.shape), value.dtype) for key, value in converted_batch.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the converted batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(converted_batch, CONVERTED_BATCH_FILE)\n",
    "CONVERTED_BATCH_FILE"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lerobot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.-1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}