# Listing metadata: 160 lines, 5.8 KiB, Python, executable file.
#coding=utf-8
|
|
import os
|
|
import numpy as np
|
|
import cv2
|
|
import h5py
|
|
import argparse
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Time step handed to save_videos(); the output frame rate is derived from it as 1 / DT.
DT = 0.02

# Generic per-arm joint labels. Descriptive names used previously:
# "waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate".
JOINT_NAMES = [f"joint{joint_idx}" for joint_idx in range(6)]

# Labels for one arm's state vector: the joints followed by the gripper.
STATE_NAMES = [*JOINT_NAMES, "gripper"]

# Labels for the columns plotted by visualize_base().
BASE_STATE_NAMES = ["linear_vel", "angular_vel"]
|
|
|
|
def load_hdf5(dataset_dir, dataset_name):
    """Load one recorded episode from ``<dataset_dir>/<dataset_name>.hdf5``.

    Args:
        dataset_dir: Directory that contains the episode file.
        dataset_name: Episode file name without the ``.hdf5`` extension.

    Returns:
        A tuple ``(qpos, qvel, effort, action, base_action, image_dict)``.
        ``effort`` is ``None`` when the file has no ``/observations/effort``
        dataset. ``image_dict`` maps every camera name found under
        ``/observations/images`` to its frames: the stored array as-is, or a
        list of decoded images when the file's ``compress`` attribute is set.

    Raises:
        SystemExit: With status 1 if the episode file does not exist.
    """
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        # Raise SystemExit directly instead of calling exit(): that helper is
        # installed by the `site` module and ends the process with status 0,
        # which hides the failure from calling scripts.
        raise SystemExit(1)

    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        # Effort lives under /observations, so membership has to be tested on
        # that group; testing the file's root keys never finds it.
        if 'effort' in root['/observations']:
            effort = root['/observations/effort'][()]
        else:
            effort = None
        action = root['/action'][()]
        base_action = root['/base_action'][()]
        image_dict = {
            cam_name: root[f'/observations/images/{cam_name}'][()]
            for cam_name in root['/observations/images/'].keys()
        }

    if compressed:
        # Each stored frame is an encoded image buffer; decode it to a colour image.
        # NOTE(review): the buffer is passed to cv2.imdecode without un-padding,
        # exactly as before; presumably the decoder tolerates trailing padding.
        # Confirm before relying on it for other codecs.
        image_dict = {
            cam_name: [cv2.imdecode(encoded_frame, 1) for encoded_frame in encoded_frames]
            for cam_name, encoded_frames in image_dict.items()
        }

    return qpos, qvel, effort, action, base_action, image_dict
|
|
|
|
def main(args):
    """Load one episode and write its video, joint plot and base-action plot.

    Args:
        args: Mapping with the keys ``dataset_dir``, ``episode_idx`` and
            ``task_name``. The episode is read from
            ``<dataset_dir>/<task_name>/episode_<episode_idx>.hdf5``; the
            generated files are written directly into ``dataset_dir``.
    """
    root_dir, episode_idx, task = (
        args[key] for key in ('dataset_dir', 'episode_idx', 'task_name')
    )
    dataset_name = f'episode_{episode_idx}'

    episode = load_hdf5(os.path.join(root_dir, task), dataset_name)
    # qvel and effort are loaded but not visualised here.
    qpos, _qvel, _effort, action, base_action, image_dict = episode
    print('hdf5 loaded!!')

    def output_path(suffix):
        # All artefacts share the episode name and differ only by suffix.
        return os.path.join(root_dir, dataset_name + suffix)

    save_videos(image_dict, action, DT, video_path=output_path('_video.mp4'))
    visualize_joints(qpos, action, plot_path=output_path('_qpos.png'))
    visualize_base(base_action, plot_path=output_path('_base_action.png'))
|
|
|
|
def save_videos(video, actions, dt, video_path=None):
    """Preview an episode in a window and write it to an mp4 file.

    The frames of all cameras are joined side by side along the width axis.
    Each frame is shown with ``cv2.imshow``, its action is printed, and the
    frame is appended to the output video.

    Args:
        video: Mapping from camera name to that camera's frames, each indexed
            as (frame, height, width, channel).
        actions: Per-frame actions; for frame ``t`` the first 7 values are
            printed as "left" and the remainder as "right".
        dt: Time step per frame; the video frame rate is ``int(1 / dt)``.
        video_path: Destination path of the mp4 file.
    """
    # Join the cameras along axis 2, the width dimension.
    all_cam_videos = np.concatenate([video[cam_name] for cam_name in video], axis=2)

    n_frames, h, w, _ = all_cam_videos.shape
    fps = int(1 / dt)
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    try:
        for t in range(n_frames):
            # Swap the B and R channels before display and writing.
            # NOTE(review): OpenCV treats frames as BGR, so this presumes the
            # incoming frames are RGB. Confirm for frames that load_hdf5
            # decoded with cv2.imdecode.
            image = all_cam_videos[t][:, :, [2, 1, 0]]
            cv2.imshow("images", image)
            cv2.waitKey(30)
            print("episode_id: ", t, "left: ", np.round(actions[t][:7], 3), "right: ", np.round(actions[t][7:], 3), "\n")
            out.write(image)
    finally:
        # Always release the writer, even if a frame fails part-way through
        # (for example when `actions` is shorter than the frame sequence).
        out.release()
    print(f'Saved video to: {video_path}')
|
|
|
|
|
|
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
    """Plot each joint-state dimension over time and save the figure.

    One subplot is drawn per column of ``qpos_list``. Titles use the
    left-arm then right-arm names built from ``STATE_NAMES``.

    Args:
        qpos_list: Two-dimensional (timestep, dimension) joint states.
        command_list: Kept for interface compatibility; commands are not
            drawn by this function.
        plot_path: Destination passed to ``plt.savefig``.
        ylim: Optional y-axis limits applied to every subplot.
        label_overwrite: Optional pair ``(state_label, command_label)``; only
            the state label is used. Defaults to ``'State'``.
    """
    if label_overwrite:
        # Unpack as a pair so a malformed override still fails loudly.
        label1, _ = label_overwrite
    else:
        label1 = 'State'

    qpos = np.array(qpos_list)  # ts, dim
    _, num_dim = qpos.shape

    # squeeze=False keeps the axes in a 2-D array, so indexing also works when
    # there is only one dimension (plt.subplots would otherwise return a bare Axes).
    _, axs = plt.subplots(num_dim, 1, figsize=(8, 2 * num_dim), squeeze=False)
    axs = axs[:, 0]

    # plot joint state
    all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
    for dim_idx, ax in enumerate(axs):
        # Fall back to a generic name when there are more columns than known names.
        dim_name = all_names[dim_idx] if dim_idx < len(all_names) else f'dim{dim_idx}'
        ax.plot(qpos[:, dim_idx], label=label1, color='orangered')
        ax.set_title(f'Joint {dim_idx}: {dim_name}')
        ax.legend()
        if ylim:
            ax.set_ylim(ylim)

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved qpos plot to: {plot_path}')
    plt.close()
|
|
|
|
|
|
def visualize_base(readings, plot_path=None):
    """Plot each base reading over time, raw and smoothed, and save the figure.

    One subplot is drawn per column of ``readings``. Each shows the raw
    series plus moving averages over 20, 10 and 5 samples. Titles use
    ``BASE_STATE_NAMES``.

    Args:
        readings: Two-dimensional (timestep, dimension) base readings.
        plot_path: Destination passed to ``plt.savefig``.
    """
    readings = np.array(readings)  # ts, dim
    _, num_dim = readings.shape

    # squeeze=False keeps the axes in a 2-D array, so indexing also works when
    # there is only one dimension (plt.subplots would otherwise return a bare Axes).
    _, axs = plt.subplots(num_dim, 1, figsize=(8, 2 * num_dim), squeeze=False)
    axs = axs[:, 0]

    # plot base readings
    all_names = BASE_STATE_NAMES
    for dim_idx, ax in enumerate(axs):
        series = readings[:, dim_idx]
        ax.plot(series, label='raw')
        for window in (20, 10, 5):
            # Moving average: convolve with a uniform kernel of `window` samples.
            kernel = np.ones(window) / window
            ax.plot(np.convolve(series, kernel, mode='same'), label=f'smoothed_{window}')
        # Fall back to a generic name when there are more columns than known names.
        dim_name = all_names[dim_idx] if dim_idx < len(all_names) else f'dim{dim_idx}'
        ax.set_title(f'Joint {dim_idx}: {dim_name}')
        ax.legend()

    plt.tight_layout()
    plt.savefig(plot_path)
    # The old message said "effort plot", but this figure shows base readings.
    print(f'Saved base action plot to: {plot_path}')
    plt.close()
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main() as a dict.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset_dir', type=str, required=True, help='Dataset dir.')
    cli.add_argument('--task_name', type=str, default="aloha_mobile_dummy", help='Task name.')
    cli.add_argument('--episode_idx', type=int, default=0, help='Episode index.')

    main(vars(cli.parse_args()))
|