This commit adds the following new files under realman_src/realman_aloha/shadow_rm_act/.

realman_src/realman_aloha/shadow_rm_act/.gitignore (vendored, new file, 10 lines)
@@ -0,0 +1,10 @@
__pycache__/
build/
devel/
dist/
data/
.catkin_workspace
*.pyc
*.pyo
*.pt
.vscode/
realman_src/realman_aloha/shadow_rm_act/README.md (new file, 89 lines)
@@ -0,0 +1,89 @@
# ACT: Action Chunking with Transformers

### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve well after the loss plateaus.

#### Project Website: https://tonyzhaozh.github.io/aloha/

This repo contains the implementation of ACT, together with two simulated environments:
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or on a real robot.
For the real robot, you also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).

### Updates:
You can find all scripted/human demos for the simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).


### Repo Structure
- ``imitate_episodes.py`` Train and evaluate ACT
- ``policy.py`` An adaptor for the ACT policy
- ``detr`` Model definitions of ACT, modified from DETR
- ``sim_env.py`` MuJoCo + DM_Control environments with joint-space control
- ``ee_sim_env.py`` MuJoCo + DM_Control environments with EE-space control
- ``scripted_policy.py`` Scripted policies for the sim environments
- ``constants.py`` Constants shared across files
- ``utils.py`` Utils such as data loading and helper functions
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset


### Installation

    conda create -n aloha python=3.8.10
    conda activate aloha
    pip install torchvision
    pip install torch
    pip install pyquaternion
    pip install pyyaml
    pip install rospkg
    pip install pexpect
    pip install mujoco==2.3.7
    pip install dm_control==1.0.14
    pip install opencv-python
    pip install matplotlib
    pip install einops
    pip install packaging
    pip install h5py
    pip install ipython
    cd act/detr && pip install -e .

### Example Usages

To set up a new terminal, run:

    conda activate aloha
    cd <path to act repo>

### Simulated experiments

We use the ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
To generate 50 episodes of scripted data, run:

    python3 record_sim_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --dataset_dir <data save dir> \
    --num_episodes 50

You can add the flag ``--onscreen_render`` to see real-time rendering.
To visualize an episode after it is collected, run

    python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0

To train ACT:

    # Transfer Cube task
    python3 imitate_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --ckpt_dir <ckpt dir> \
    --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
    --num_epochs 2000 --lr 1e-5 \
    --seed 0


To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
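For example, evaluating the Transfer Cube policy trained above is the same command with ``--eval`` appended (``<ckpt dir>`` is whatever directory you trained into):

    python3 imitate_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --ckpt_dir <ckpt dir> \
    --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
    --num_epochs 2000 --lr 1e-5 \
    --seed 0 --eval
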
The success rate should be around 90% for transfer cube, and around 50% for insertion.
To enable temporal ensembling, add the flag ``--temporal_agg``.
Videos will be saved to ``<ckpt_dir>`` for each rollout.
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.

For real-world data, where things can be harder to model, train for at least 5000 epochs, or 3-4 times as long after the loss has plateaued.
Please refer to [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.

realman_src/realman_aloha/shadow_rm_act/config/config.yaml (new file, 74 lines)
@@ -0,0 +1,74 @@
robot_env: {
  # TODO change the path to the correct one
  rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
  rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
  arm_axis: 6,
  head_camera: '215222076892',
  bottom_camera: '215222076981',
  left_camera: '152122078151',
  right_camera: '152122073489',
  # init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
  # init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
  init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
  init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}
dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
# checkpoint_name: 'policy_best.ckpt'
checkpoint_name: 'policy_9500.ckpt'
state_dim: 14
save_episode: True
num_rollouts: 50 # number of rollouts (trajectories) to collect
real_robot: True
policy_class: 'ACT'
onscreen_render: False
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
episode_len: 300 # maximum episode length (number of timesteps)
task_name: 'aloha_01_11.28'
temporal_agg: False # whether to use temporal aggregation
batch_size: 8 # number of samples per batch during training
seed: 1000 # random seed
chunk_size: 30 # chunk size used when processing sequences
eval_every: 1 # evaluate the model every eval_every steps
num_steps: 10000 # total number of training steps
validate_every: 1 # validate the model every validate_every steps
save_every: 500 # save a checkpoint every save_every steps
load_pretrain: False # whether to load a pretrained model
resume_ckpt_path:
name_filter: # TODO
skip_mirrored_data: False # whether to skip mirrored data (e.g. for symmetry-based data augmentation)
stats_dir:
sample_weights:
train_ratio: 0.8 # fraction of the data used for training (the rest is used for validation)

policy_config: {
  hidden_dim: 512, # size of the embeddings (dimension of the transformer)
  state_dim: 14, # dimension of the state
  position_embedding: 'sine', # ('sine', 'learned') type of positional embedding to use on top of the image features
  lr_backbone: 1.0e-5,
  masks: False, # if true, the model masks the non-visible pixels
  backbone: 'resnet18',
  dilation: False, # if true, replace stride with dilation in the last convolutional block (DC5)
  dropout: 0.1, # dropout applied in the transformer
  nheads: 8,
  dim_feedforward: 3200, # intermediate size of the feedforward layers in the transformer blocks
  enc_layers: 4, # number of encoder layers in the transformer
  dec_layers: 7, # number of decoder layers in the transformer
  pre_norm: False, # if true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
  num_queries: 30,
  camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
  vq: False,
  vq_class: none,
  vq_dim: 64,
  action_dim: 14,
  no_encoder: False,
  lr: 1.0e-5,
  weight_decay: 1.0e-4,
  kl_weight: 10,

  # lr_drop: 200,
  # clip_max_norm: 0.1,
}

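A minimal sketch of loading this file with PyYAML and handing it to the evaluator below (the relative path is illustrative, not part of the repo):

    import yaml

    # Load the evaluation config; the path is an assumption for illustration.
    with open("config/config.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Nested flow mappings parse into plain dicts, so values are reachable by key.
    print(config["policy_config"]["num_queries"])  # -> 30
    print(config["camera_names"])                  # -> ['cam_high', 'cam_low', 'cam_left', 'cam_right']
    # note: plain `none` (as in vq_class above) parses as the string 'none', not a YAML null
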
realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import numpy as np
import collections
import os

from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_CLOSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN

from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base

import IPython
e = IPython.embed


def make_ee_sim_env(task_name):
    """
    Environment for simulated robot bi-manual manipulation, with end-effector control.
    Action space: [left_arm_pose (7),            # position and quaternion for end effector
                   left_gripper_positions (1),   # normalized gripper position (0: close, 1: open)
                   right_arm_pose (7),           # position and quaternion for end effector
                   right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                        right_arm_qpos (6),         # absolute joint position
                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[ left_arm_qvel (6),          # absolute joint velocity (rad)
                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                        right_arm_qvel (6),         # absolute joint velocity (rad)
                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
                        "images": {"main": (480x640x3)}             # h, w, c, dtype='uint8'
    """
    if 'sim_transfer_cube' in task_name:
        xml_path = os.path.join(XML_DIR, 'bimanual_viperx_ee_transfer_cube.xml')
        physics = mujoco.Physics.from_xml_path(xml_path)
        task = TransferCubeEETask(random=False)
        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
                                  n_sub_steps=None, flat_observation=False)
    elif 'sim_insertion' in task_name:
        xml_path = os.path.join(XML_DIR, 'bimanual_viperx_ee_insertion.xml')
        physics = mujoco.Physics.from_xml_path(xml_path)
        task = InsertionEETask(random=False)
        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
                                  n_sub_steps=None, flat_observation=False)
    else:
        raise NotImplementedError
    return env


class BimanualViperXEETask(base.Task):
    def __init__(self, random=None):
        super().__init__(random=random)

    def before_step(self, action, physics):
        a_len = len(action) // 2
        action_left = action[:a_len]
        action_right = action[a_len:]

        # set mocap position and quat
        # left
        np.copyto(physics.data.mocap_pos[0], action_left[:3])
        np.copyto(physics.data.mocap_quat[0], action_left[3:7])
        # right
        np.copyto(physics.data.mocap_pos[1], action_right[:3])
        np.copyto(physics.data.mocap_quat[1], action_right[3:7])

        # set gripper
        g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
        g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
        np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))

    def initialize_robots(self, physics):
        # reset joint position
        physics.named.data.qpos[:16] = START_ARM_POSE

        # reset mocap to align with end effector
        # to obtain these numbers:
        # (1) make an ee_sim env and reset to the same start_pose
        # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
        #     get env._physics.named.data.xquat['vx300s_left/gripper_link']
        #     repeat the same for the right side
        np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
        np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
        # right
        np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
        np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])

        # reset gripper control
        close_gripper_control = np.array([
            PUPPET_GRIPPER_POSITION_CLOSE,
            -PUPPET_GRIPPER_POSITION_CLOSE,
            PUPPET_GRIPPER_POSITION_CLOSE,
            -PUPPET_GRIPPER_POSITION_CLOSE,
        ])
        np.copyto(physics.data.ctrl, close_gripper_control)

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        super().initialize_episode(physics)

    @staticmethod
    def get_qpos(physics):
        qpos_raw = physics.data.qpos.copy()
        left_qpos_raw = qpos_raw[:8]
        right_qpos_raw = qpos_raw[8:16]
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
        right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    @staticmethod
    def get_qvel(physics):
        qvel_raw = physics.data.qvel.copy()
        left_qvel_raw = qvel_raw[:8]
        right_qvel_raw = qvel_raw[8:16]
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
        right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    @staticmethod
    def get_env_state(physics):
        raise NotImplementedError

    def get_observation(self, physics):
        # note: it is important to do .copy()
        obs = collections.OrderedDict()
        obs['qpos'] = self.get_qpos(physics)
        obs['qvel'] = self.get_qvel(physics)
        obs['env_state'] = self.get_env_state(physics)
        obs['images'] = dict()
        obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
        obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
        obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
        # used in scripted policy to obtain starting pose
        obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
        obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()

        # used when replaying joint trajectory
        obs['gripper_ctrl'] = physics.data.ctrl.copy()
        return obs

    def get_reward(self, physics):
        raise NotImplementedError


class TransferCubeEETask(BimanualViperXEETask):
    def __init__(self, random=None):
        super().__init__(random=random)
        self.max_reward = 4

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        self.initialize_robots(physics)
        # randomize box position
        cube_pose = sample_box_pose()
        box_start_idx = physics.model.name2id('red_box_joint', 'joint')
        np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
        # print(f"randomized cube position to {cube_position}")

        super().initialize_episode(physics)

    @staticmethod
    def get_env_state(physics):
        env_state = physics.data.qpos.copy()[16:]
        return env_state

    def get_reward(self, physics):
        # return whether the left gripper is holding the box
        all_contact_pairs = []
        for i_contact in range(physics.data.ncon):
            id_geom_1 = physics.data.contact[i_contact].geom1
            id_geom_2 = physics.data.contact[i_contact].geom2
            name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
            name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
            contact_pair = (name_geom_1, name_geom_2)
            all_contact_pairs.append(contact_pair)

        touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
        touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
        touch_table = ("red_box", "table") in all_contact_pairs

        reward = 0
        if touch_right_gripper:
            reward = 1
        if touch_right_gripper and not touch_table:  # lifted
            reward = 2
        if touch_left_gripper:  # attempted transfer
            reward = 3
        if touch_left_gripper and not touch_table:  # successful transfer
            reward = 4
        return reward


class InsertionEETask(BimanualViperXEETask):
    def __init__(self, random=None):
        super().__init__(random=random)
        self.max_reward = 4

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        self.initialize_robots(physics)
        # randomize peg and socket position
        peg_pose, socket_pose = sample_insertion_pose()
        id2index = lambda j_id: 16 + (j_id - 16) * 7  # first 16 is robot qpos, 7 is pose dim # hacky

        peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
        peg_start_idx = id2index(peg_start_id)
        np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)

        socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
        socket_start_idx = id2index(socket_start_id)
        np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)

        super().initialize_episode(physics)

    @staticmethod
    def get_env_state(physics):
        env_state = physics.data.qpos.copy()[16:]
        return env_state

    def get_reward(self, physics):
        # return whether the peg touches the pin
        all_contact_pairs = []
        for i_contact in range(physics.data.ncon):
            id_geom_1 = physics.data.contact[i_contact].geom1
            id_geom_2 = physics.data.contact[i_contact].geom2
            name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
            name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
            contact_pair = (name_geom_1, name_geom_2)
            all_contact_pairs.append(contact_pair)

        touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
        touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs

        peg_touch_table = ("red_peg", "table") in all_contact_pairs
        socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
                             ("socket-2", "table") in all_contact_pairs or \
                             ("socket-3", "table") in all_contact_pairs or \
                             ("socket-4", "table") in all_contact_pairs
        peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
                           ("red_peg", "socket-2") in all_contact_pairs or \
                           ("red_peg", "socket-3") in all_contact_pairs or \
                           ("red_peg", "socket-4") in all_contact_pairs
        pin_touched = ("red_peg", "pin") in all_contact_pairs

        reward = 0
        if touch_left_gripper and touch_right_gripper:  # touch both
            reward = 1
        if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table):  # grasp both
            reward = 2
        if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table):  # peg and socket touching
            reward = 3
        if pin_touched:  # successful insertion
            reward = 4
        return reward
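As a quick sanity check of the interface documented in the `make_ee_sim_env` docstring, a minimal rollout might look like the sketch below (assumes the repo's `constants.py`, utilities, and MuJoCo XML assets are importable; the hold-in-place action is illustrative only):

    import numpy as np
    from ee_sim_env import make_ee_sim_env

    env = make_ee_sim_env('sim_transfer_cube_scripted')
    ts = env.reset()

    # Hold both end effectors at their current mocap poses with grippers open.
    left = np.concatenate([ts.observation['mocap_pose_left'], [1.0]])    # 7D pose + gripper
    right = np.concatenate([ts.observation['mocap_pose_right'], [1.0]])
    action = np.concatenate([left, right])                               # 16D action

    for _ in range(10):
        ts = env.step(action)
    print(ts.observation['qpos'].shape)  # -> (14,)
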
realman_src/realman_aloha/shadow_rm_act/pyproject.toml (new file, 36 lines)
@@ -0,0 +1,36 @@
[tool.poetry]
name = "shadow_act"
version = "0.1.0"
description = "Embodied data, ACT and other methods; training and verification function packages"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
    "Operating System :: POSIX :: Linux amd64",
    "Programming Language :: Python :: 3.10",
]

[tool.poetry.dependencies]
python = ">=3.9"
wandb = ">=0.18.0"
einops = ">=0.8.0"


[tool.poetry.dev-dependencies]  # development-time dependencies, e.g. testing and documentation tools
pytest = ">=8.3"
black = ">=24.10.0"


[tool.poetry.plugins."scripts"]  # command-line scripts that let users run a given function from the command line


[tool.poetry.group.dev.dependencies]


[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"

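With this file in place, the package can be installed into a virtual environment in the usual Poetry way (a sketch; assumes Poetry is installed, and the `record_sim_episodes.py` invocation is illustrative):

    poetry install
    poetry run python record_sim_episodes.py --task_name sim_transfer_cube_scripted --dataset_dir data/ --num_episodes 50
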
realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py (new file, 189 lines)
@@ -0,0 +1,189 @@
import time
import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import h5py

from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy

import IPython
e = IPython.embed


def main(args):
    """
    Generate demonstration data in simulation.
    First roll out the policy (defined in EE space) in ee_sim_env to obtain the joint trajectory.
    Replace the gripper joint positions with the commanded joint positions.
    Replay this joint trajectory (as an action sequence) in sim_env, and record all observations.
    Save this episode of data, and continue to the next episode of data collection.
    """

    task_name = args['task_name']
    dataset_dir = args['dataset_dir']
    num_episodes = args['num_episodes']
    onscreen_render = args['onscreen_render']
    inject_noise = False
    render_cam_name = 'angle'

    if not os.path.isdir(dataset_dir):
        os.makedirs(dataset_dir, exist_ok=True)

    episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
    camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
    if task_name == 'sim_transfer_cube_scripted':
        policy_cls = PickAndTransferPolicy
    elif task_name == 'sim_insertion_scripted':
        policy_cls = InsertionPolicy
    else:
        raise NotImplementedError

    success = []
    for episode_idx in range(num_episodes):
        print(f'{episode_idx=}')
        print('Rolling out EE space scripted policy')
        # setup the environment
        env = make_ee_sim_env(task_name)
        ts = env.reset()
        episode = [ts]
        policy = policy_cls(inject_noise)
        # setup plotting
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images'][render_cam_name])
            plt.ion()
        for step in range(episode_len):
            action = policy(ts)
            ts = env.step(action)
            episode.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images'][render_cam_name])
                plt.pause(0.002)
        plt.close()

        episode_return = np.sum([ts.reward for ts in episode[1:]])
        episode_max_reward = np.max([ts.reward for ts in episode[1:]])
        if episode_max_reward == env.task.max_reward:
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            print(f"{episode_idx=} Failed")

        joint_traj = [ts.observation['qpos'] for ts in episode]
        # replace gripper pose with gripper control
        gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
        for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
            left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
            right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
            joint[6] = left_ctrl
            joint[6+7] = right_ctrl

        subtask_info = episode[0].observation['env_state'].copy()  # box pose at step 0

        # clear unused variables
        del env
        del episode
        del policy

        # setup the environment
        print('Replaying joint commands')
        env = make_sim_env(task_name)
        BOX_POSE[0] = subtask_info  # make sure the sim_env has the same object configurations as ee_sim_env
        ts = env.reset()

        episode_replay = [ts]
        # setup plotting
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images'][render_cam_name])
            plt.ion()
        for t in range(len(joint_traj)):  # note: this will increase episode length by 1
            action = joint_traj[t]
            ts = env.step(action)
            episode_replay.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images'][render_cam_name])
                plt.pause(0.02)

        episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
        episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
        if episode_max_reward == env.task.max_reward:
            success.append(1)
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            success.append(0)
            print(f"{episode_idx=} Failed")

        plt.close()

        """
        For each timestep:
        observations
        - images
            - each_cam_name (480, 640, 3) 'uint8'
        - qpos              (14,)         'float64'
        - qvel              (14,)         'float64'

        action              (14,)         'float64'
        """

        data_dict = {
            '/observations/qpos': [],
            '/observations/qvel': [],
            '/action': [],
        }
        for cam_name in camera_names:
            data_dict[f'/observations/images/{cam_name}'] = []

        # because of the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
        # truncate here to be consistent
        joint_traj = joint_traj[:-1]
        episode_replay = episode_replay[:-1]

        # len(joint_traj) i.e. actions: max_timesteps
        # len(episode_replay) i.e. time steps: max_timesteps + 1
        max_timesteps = len(joint_traj)
        while joint_traj:
            action = joint_traj.pop(0)
            ts = episode_replay.pop(0)
            data_dict['/observations/qpos'].append(ts.observation['qpos'])
            data_dict['/observations/qvel'].append(ts.observation['qvel'])
            data_dict['/action'].append(action)
            for cam_name in camera_names:
                data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])

        # HDF5
        t0 = time.time()
        dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
        with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
            root.attrs['sim'] = True
            obs = root.create_group('observations')
            image = obs.create_group('images')
            for cam_name in camera_names:
                _ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
                                         chunks=(1, 480, 640, 3), )
                # compression='gzip', compression_opts=2,)
                # compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
            qpos = obs.create_dataset('qpos', (max_timesteps, 14))
            qvel = obs.create_dataset('qvel', (max_timesteps, 14))
            action = root.create_dataset('action', (max_timesteps, 14))

            for name, array in data_dict.items():
                root[name][...] = array
        print(f'Saving: {time.time() - t0:.1f} secs\n')

    print(f'Saved to {dataset_dir}')
    print(f'Success: {np.sum(success)} / {len(success)}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
    parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
    parser.add_argument('--onscreen_render', action='store_true')

    main(vars(parser.parse_args()))
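The docstring above fixes the on-disk layout of each episode; reading one back is a few lines of h5py (a sketch; the file path and camera name are illustrative):

    import h5py

    # Open one recorded episode; the path is an assumption for illustration.
    with h5py.File('data/episode_0.hdf5', 'r') as root:
        print(root.attrs['sim'])                   # True for sim-collected data
        qpos = root['/observations/qpos'][()]      # (T, 14) float64
        actions = root['/action'][()]              # (T, 14) float64
        top = root['/observations/images/top'][0]  # (480, 640, 3) uint8, first frame
        print(qpos.shape, actions.shape, top.dtype)
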
realman_src/realman_aloha/shadow_rm_act/scripted_policy.py (new file, 194 lines)
@@ -0,0 +1,194 @@
import numpy as np
import matplotlib.pyplot as plt
from pyquaternion import Quaternion

from constants import SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env

import IPython
e = IPython.embed


class BasePolicy:
    def __init__(self, inject_noise=False):
        self.inject_noise = inject_noise
        self.step_count = 0
        self.left_trajectory = None
        self.right_trajectory = None

    def generate_trajectory(self, ts_first):
        raise NotImplementedError

    @staticmethod
    def interpolate(curr_waypoint, next_waypoint, t):
        t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
        curr_xyz = curr_waypoint['xyz']
        curr_quat = curr_waypoint['quat']
        curr_grip = curr_waypoint['gripper']
        next_xyz = next_waypoint['xyz']
        next_quat = next_waypoint['quat']
        next_grip = next_waypoint['gripper']
        xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
        quat = curr_quat + (next_quat - curr_quat) * t_frac  # linear (not spherical) quaternion interpolation
        gripper = curr_grip + (next_grip - curr_grip) * t_frac
        return xyz, quat, gripper

    def __call__(self, ts):
        # generate trajectory at the first timestep, then execute open-loop
        if self.step_count == 0:
            self.generate_trajectory(ts)

        # obtain left and right waypoints
        if self.left_trajectory[0]['t'] == self.step_count:
            self.curr_left_waypoint = self.left_trajectory.pop(0)
        next_left_waypoint = self.left_trajectory[0]

        if self.right_trajectory[0]['t'] == self.step_count:
            self.curr_right_waypoint = self.right_trajectory.pop(0)
        next_right_waypoint = self.right_trajectory[0]

        # interpolate between waypoints to obtain current pose and gripper command
        left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
        right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)

        # inject noise
        if self.inject_noise:
            scale = 0.01
            left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
            right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)

        action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
        action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])

        self.step_count += 1
        return np.concatenate([action_left, action_right])


class PickAndTransferPolicy(BasePolicy):

    def generate_trajectory(self, ts_first):
        init_mocap_pose_right = ts_first.observation['mocap_pose_right']
        init_mocap_pose_left = ts_first.observation['mocap_pose_left']

        box_info = np.array(ts_first.observation['env_state'])
        box_xyz = box_info[:3]
        box_quat = box_info[3:]
        # print(f"Generate trajectory for {box_xyz=}")

        gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
        gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)

        meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)

        meet_xyz = np.array([0, 0.5, 0.25])

        self.left_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0},  # sleep
            {"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1},  # approach meet position
            {"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1},  # move to meet position
            {"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0},  # close gripper
            {"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0},  # move left
            {"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0},  # stay
        ]

        self.right_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0},  # sleep
            {"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1},  # approach the cube
            {"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1},  # go down
            {"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0},  # close gripper
            {"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0},  # approach meet position
            {"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0},  # move to meet position
            {"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1},  # open gripper
            {"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1},  # move to right
            {"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1},  # stay
        ]


class InsertionPolicy(BasePolicy):

    def generate_trajectory(self, ts_first):
        init_mocap_pose_right = ts_first.observation['mocap_pose_right']
        init_mocap_pose_left = ts_first.observation['mocap_pose_left']

        peg_info = np.array(ts_first.observation['env_state'])[:7]
        peg_xyz = peg_info[:3]
        peg_quat = peg_info[3:]

        socket_info = np.array(ts_first.observation['env_state'])[7:]
        socket_xyz = socket_info[:3]
        socket_quat = socket_info[3:]

        gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
        gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)

        gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
        gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)

        meet_xyz = np.array([0, 0.5, 0.15])
        lift_right = 0.00715

        self.left_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0},  # sleep
            {"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1},  # approach the cube
            {"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1},  # go down
            {"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # close gripper
            {"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # approach meet position
            {"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # insertion
            {"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0},  # insertion
        ]

        self.right_trajectory = [
            {"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0},  # sleep
            {"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1},  # approach the cube
            {"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1},  # go down
            {"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # close gripper
            {"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # approach meet position
            {"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # insertion
            {"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0},  # insertion
        ]


def test_policy(task_name):
    # example rolling out the pick_and_transfer policy
    onscreen_render = True
    inject_noise = False

    # setup the environment
    episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
    if 'sim_transfer_cube' in task_name:
        env = make_ee_sim_env('sim_transfer_cube')
    elif 'sim_insertion' in task_name:
        env = make_ee_sim_env('sim_insertion')
    else:
        raise NotImplementedError

    for episode_idx in range(2):
        ts = env.reset()
        episode = [ts]
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(ts.observation['images']['angle'])
            plt.ion()

        policy = PickAndTransferPolicy(inject_noise)
        for step in range(episode_len):
            action = policy(ts)
            ts = env.step(action)
            episode.append(ts)
            if onscreen_render:
                plt_img.set_data(ts.observation['images']['angle'])
                plt.pause(0.02)
        plt.close()

        episode_return = np.sum([ts.reward for ts in episode[1:]])
        if episode_return > 0:
            print(f"{episode_idx=} Successful, {episode_return=}")
        else:
            print(f"{episode_idx=} Failed")


if __name__ == '__main__':
    test_task_name = 'sim_transfer_cube_scripted'
    test_policy(test_task_name)
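The waypoint scheme above is easy to check by hand; a short sketch evaluating `BasePolicy.interpolate` halfway between two hypothetical waypoints:

    import numpy as np
    from scripted_policy import BasePolicy

    # Two hypothetical waypoints 100 steps apart.
    wp0 = {"t": 0,   "xyz": np.zeros(3),           "quat": np.array([1, 0, 0, 0]), "gripper": 0}
    wp1 = {"t": 100, "xyz": np.array([0.2, 0, 0]), "quat": np.array([1, 0, 0, 0]), "gripper": 1}

    xyz, quat, grip = BasePolicy.interpolate(wp0, wp1, t=50)
    print(xyz)   # -> [0.1 0.  0. ]  (halfway along x)
    print(grip)  # -> 0.5            (gripper half open)
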
realman_src/realman_aloha/shadow_rm_act/sim_env.py (new file, 278 lines)
@@ -0,0 +1,278 @@
import numpy as np
import os
import collections
import matplotlib.pyplot as plt
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base

from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN

import IPython
e = IPython.embed

BOX_POSE = [None]  # to be changed from outside


def make_sim_env(task_name):
    """
    Environment for simulated robot bi-manual manipulation, with joint position control
    Action space: [left_arm_qpos (6),             # absolute joint position
                   left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)
                   right_arm_qpos (6),            # absolute joint position
                   right_gripper_positions (1),]  # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                        right_arm_qpos (6),         # absolute joint position
                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[ left_arm_qvel (6),          # absolute joint velocity (rad)
                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                        right_arm_qvel (6),         # absolute joint velocity (rad)
                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
                        "images": {"main": (480x640x3)}             # h, w, c, dtype='uint8'
    """
    if 'sim_transfer_cube' in task_name:
        xml_path = os.path.join(XML_DIR, 'bimanual_viperx_transfer_cube.xml')
        physics = mujoco.Physics.from_xml_path(xml_path)
        task = TransferCubeTask(random=False)
        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
                                  n_sub_steps=None, flat_observation=False)
    elif 'sim_insertion' in task_name:
        xml_path = os.path.join(XML_DIR, 'bimanual_viperx_insertion.xml')
        physics = mujoco.Physics.from_xml_path(xml_path)
        task = InsertionTask(random=False)
        env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
                                  n_sub_steps=None, flat_observation=False)
    else:
        raise NotImplementedError
    return env


class BimanualViperXTask(base.Task):
    def __init__(self, random=None):
        super().__init__(random=random)

    def before_step(self, action, physics):
        left_arm_action = action[:6]
        right_arm_action = action[7:7+6]
        normalized_left_gripper_action = action[6]
        normalized_right_gripper_action = action[7+6]

        left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
        right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)

        full_left_gripper_action = [left_gripper_action, -left_gripper_action]
        full_right_gripper_action = [right_gripper_action, -right_gripper_action]

        env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
        super().before_step(env_action, physics)
        return

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        super().initialize_episode(physics)

    @staticmethod
    def get_qpos(physics):
        qpos_raw = physics.data.qpos.copy()
        left_qpos_raw = qpos_raw[:8]
        right_qpos_raw = qpos_raw[8:16]
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
        right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    @staticmethod
    def get_qvel(physics):
        qvel_raw = physics.data.qvel.copy()
        left_qvel_raw = qvel_raw[:8]
        right_qvel_raw = qvel_raw[8:16]
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
        right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    @staticmethod
    def get_env_state(physics):
        raise NotImplementedError

    def get_observation(self, physics):
        obs = collections.OrderedDict()
        obs['qpos'] = self.get_qpos(physics)
        obs['qvel'] = self.get_qvel(physics)
        obs['env_state'] = self.get_env_state(physics)
        obs['images'] = dict()
        obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
        obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
        obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')

        return obs

    def get_reward(self, physics):
        # return whether the left gripper is holding the box
        raise NotImplementedError


class TransferCubeTask(BimanualViperXTask):
    def __init__(self, random=None):
        super().__init__(random=random)
        self.max_reward = 4

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
        # reset qpos, control and box position
        with physics.reset_context():
            physics.named.data.qpos[:16] = START_ARM_POSE
            np.copyto(physics.data.ctrl, START_ARM_POSE)
            assert BOX_POSE[0] is not None
            physics.named.data.qpos[-7:] = BOX_POSE[0]
            # print(f"{BOX_POSE=}")
        super().initialize_episode(physics)

    @staticmethod
    def get_env_state(physics):
        env_state = physics.data.qpos.copy()[16:]
        return env_state

    def get_reward(self, physics):
        # return whether the left gripper is holding the box
        all_contact_pairs = []
        for i_contact in range(physics.data.ncon):
            id_geom_1 = physics.data.contact[i_contact].geom1
            id_geom_2 = physics.data.contact[i_contact].geom2
            name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
            name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
            contact_pair = (name_geom_1, name_geom_2)
            all_contact_pairs.append(contact_pair)

        touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
        touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
        touch_table = ("red_box", "table") in all_contact_pairs

        reward = 0
        if touch_right_gripper:
            reward = 1
        if touch_right_gripper and not touch_table:  # lifted
            reward = 2
        if touch_left_gripper:  # attempted transfer
            reward = 3
        if touch_left_gripper and not touch_table:  # successful transfer
            reward = 4
        return reward


class InsertionTask(BimanualViperXTask):
    def __init__(self, random=None):
        super().__init__(random=random)
        self.max_reward = 4

    def initialize_episode(self, physics):
        """Sets the state of the environment at the start of each episode."""
        # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
        # reset qpos, control and box position
        with physics.reset_context():
            physics.named.data.qpos[:16] = START_ARM_POSE
            np.copyto(physics.data.ctrl, START_ARM_POSE)
            assert BOX_POSE[0] is not None
            physics.named.data.qpos[-7*2:] = BOX_POSE[0]  # two objects
            # print(f"{BOX_POSE=}")
        super().initialize_episode(physics)

    @staticmethod
    def get_env_state(physics):
        env_state = physics.data.qpos.copy()[16:]
        return env_state

    def get_reward(self, physics):
        # return whether the peg touches the pin
        all_contact_pairs = []
        for i_contact in range(physics.data.ncon):
            id_geom_1 = physics.data.contact[i_contact].geom1
            id_geom_2 = physics.data.contact[i_contact].geom2
            name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
            name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
            contact_pair = (name_geom_1, name_geom_2)
            all_contact_pairs.append(contact_pair)

        touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
        touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
                             ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs

        peg_touch_table = ("red_peg", "table") in all_contact_pairs
        socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
                             ("socket-2", "table") in all_contact_pairs or \
                             ("socket-3", "table") in all_contact_pairs or \
                             ("socket-4", "table") in all_contact_pairs
        peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
                           ("red_peg", "socket-2") in all_contact_pairs or \
                           ("red_peg", "socket-3") in all_contact_pairs or \
                           ("red_peg", "socket-4") in all_contact_pairs
        pin_touched = ("red_peg", "pin") in all_contact_pairs

        reward = 0
        if touch_left_gripper and touch_right_gripper:  # touch both
            reward = 1
        if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table):  # grasp both
            reward = 2
        if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table):  # peg and socket touching
            reward = 3
        if pin_touched:  # successful insertion
            reward = 4
        return reward


def get_action(master_bot_left, master_bot_right):
    action = np.zeros(14)
    # arm action
    action[:6] = master_bot_left.dxl.joint_states.position[:6]
    action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
    # gripper action
    left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
    right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
    normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
    normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
    action[6] = normalized_left_pos
    action[7+6] = normalized_right_pos
    return action


def test_sim_teleop():
    """ Testing teleoperation in sim with ALOHA. Requires hardware and the ALOHA repo to work. """
    from interbotix_xs_modules.arm import InterbotixManipulatorXS

    BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]

    # source of data
    master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
                                              robot_name='master_left', init_node=True)
    master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
                                               robot_name='master_right', init_node=False)

    # setup the environment
    env = make_sim_env('sim_transfer_cube')
    ts = env.reset()
    episode = [ts]
    # setup plotting
    ax = plt.subplot()
    plt_img = ax.imshow(ts.observation['images']['angle'])
    plt.ion()

    for t in range(1000):
        action = get_action(master_bot_left, master_bot_right)
        ts = env.step(action)
        episode.append(ts)

        plt_img.set_data(ts.observation['images']['angle'])
        plt.pause(0.02)


if __name__ == '__main__':
    test_sim_teleop()
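`BOX_POSE` is a one-element list mutated from outside before `env.reset()`, which is how `record_sim_episodes.py` and `test_sim_teleop` above seed the object pose; a minimal sketch of the same pattern (the pose values are illustrative):

    import numpy as np
    from sim_env import make_sim_env, BOX_POSE

    # The task asserts BOX_POSE[0] is set before reset, so seed it first.
    BOX_POSE[0] = np.array([0.2, 0.5, 0.05, 1, 0, 0, 0])  # xyz + wxyz quaternion

    env = make_sim_env('sim_transfer_cube_scripted')
    ts = env.reset()                    # initialize_episode reads BOX_POSE[0]
    print(ts.observation['env_state'])  # object qpos, mirrors the injected pose
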
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,575 @@
|
||||
import os
|
||||
import time
|
||||
import yaml
|
||||
import torch
|
||||
import pickle
|
||||
import dm_env
|
||||
import logging
|
||||
import collections
|
||||
import numpy as np
|
||||
import tracemalloc
|
||||
from einops import rearrange
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import transforms
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
from shadow_act.models.latent_model import Latent_Model_Transformer
|
||||
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||
from shadow_act.utils.utils import set_seed
|
||||
|
||||
|
||||
# 配置logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
# # 隐藏h5py的警告Creating converter from 7 to 5
|
||||
# logging.getLogger("h5py").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RmActEvaluator:
|
||||
def __init__(self, config, save_episode=True, num_rollouts=50):
|
||||
"""
|
||||
初始化Evaluator类
|
||||
|
||||
Args:
|
||||
config (dict): 配置字典
|
||||
checkpoint_name (str): 检查点名称
|
||||
save_episode (bool): 是否保存每个episode
|
||||
num_rollouts (int): 滚动次数
|
||||
"""
|
||||
self.config = config
|
||||
self._seed = config["seed"]
|
||||
self.robot_env = config["robot_env"]
|
||||
self.checkpoint_dir = config["checkpoint_dir"]
|
||||
self.checkpoint_name = config["checkpoint_name"]
|
||||
self.save_episode = save_episode
|
||||
self.num_rollouts = num_rollouts
|
||||
self.state_dim = config["state_dim"]
|
||||
self.real_robot = config["real_robot"]
|
||||
self.policy_class = config["policy_class"]
|
||||
self.onscreen_render = config["onscreen_render"]
|
||||
self.camera_names = config["camera_names"]
|
||||
self.max_timesteps = config["episode_len"]
|
||||
self.task_name = config["task_name"]
|
||||
self.temporal_agg = config["temporal_agg"]
|
||||
self.onscreen_cam = "angle"
|
||||
self.policy_config = config["policy_config"]
|
||||
self.vq = config["policy_config"]["vq"]
|
||||
# self.actuator_config = config["actuator_config"]
|
||||
# self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
|
||||
self.stats = None
|
||||
self.env = None
|
||||
self.env_max_reward = 0
|
||||
|
||||
def _make_policy(self, policy_class, policy_config):
|
||||
"""
|
||||
根据策略类和配置创建策略对象
|
||||
|
||||
Args:
|
||||
policy_class (str): 策略类名称
|
||||
policy_config (dict): 策略配置字典
|
||||
|
||||
Returns:
|
||||
policy: 创建的策略对象
|
||||
"""
|
||||
if policy_class == "ACT":
|
||||
return ACTPolicy(policy_config)
|
||||
elif policy_class == "CNNMLP":
|
||||
return CNNMLPPolicy(policy_config)
|
||||
elif policy_class == "Diffusion":
|
||||
return DiffusionPolicy(policy_config)
|
||||
else:
|
||||
raise NotImplementedError(f"Policy class {policy_class} is not implemented")
|
||||
|
||||
def load_policy_and_stats(self):
|
||||
"""
|
||||
加载策略和统计数据
|
||||
"""
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
|
||||
logging.info(f"Loading policy from: {checkpoint_path}")
|
||||
self.policy = self._make_policy(self.policy_class, self.policy_config)
|
||||
# 加载模型并设置为评估模式
|
||||
self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
|
||||
self.policy.cuda()
|
||||
self.policy.eval()
|
||||
|
||||
if self.vq:
|
||||
vq_dim = self.config["policy_config"]["vq_dim"]
|
||||
vq_class = self.config["policy_config"]["vq_class"]
|
||||
self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
|
||||
latent_model_checkpoint_path = os.path.join(
|
||||
self.checkpoint_dir, "latent_model_last.ckpt"
|
||||
)
|
||||
self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
|
||||
self.latent_model.eval()
|
||||
self.latent_model.cuda()
|
||||
logging.info(
|
||||
f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
|
||||
)
|
||||
else:
|
||||
logging.info(f"Loaded: {checkpoint_path}")
|
||||
|
||||
stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
|
||||
with open(stats_path, "rb") as f:
|
||||
self.stats = pickle.load(f)
|
||||
|
||||
def pre_process(self, state_qpos):
|
||||
"""
|
||||
预处理状态位置
|
||||
|
||||
Args:
|
||||
state_qpos (np.array): 状态位置数组
|
||||
|
||||
Returns:
|
||||
np.array: 预处理后的状态位置
|
||||
"""
|
||||
if self.policy_class == "Diffusion":
|
||||
return ((state_qpos + 1) / 2) * (
|
||||
self.stats["action_max"] - self.stats["action_min"]
|
||||
) + self.stats["action_min"]
|
||||
# 标准化处理,均值为 0,标准差为 1
|
||||
|
||||
return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]
|
||||
|
||||
def post_process(self, action):
|
||||
"""
|
||||
后处理动作
|
||||
|
||||
Args:
|
||||
action (np.array): 动作数组
|
||||
|
||||
Returns:
|
||||
np.array: 后处理后的动作
|
||||
"""
|
||||
# 反标准化处理
|
||||
return action * self.stats["action_std"] + self.stats["action_mean"]
|
||||

    def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
        """
        Fetch the current camera images as a torch tensor.

        Args:
            timestep (object): environment timestep object
            camera_names (list): list of camera names
            random_crop_resize (bool): whether to apply a center crop followed by a resize

        Returns:
            torch.Tensor: normalized images of shape (1, num_cameras, channels, height, width)
        """
        current_images = []
        for cam_name in camera_names:
            current_image = rearrange(
                timestep.observation["images"][cam_name], "h w c -> c h w"
            )
            current_images.append(current_image)
        current_image = np.stack(current_images, axis=0)
        current_image = (
            torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
        )

        if random_crop_resize:
            logging.info("Random crop resize is used!")
            original_size = current_image.shape[-2:]
            ratio = 0.95
            current_image = current_image[
                ...,
                int(original_size[0] * (1 - ratio) / 2) : int(
                    original_size[0] * (1 + ratio) / 2
                ),
                int(original_size[1] * (1 - ratio) / 2) : int(
                    original_size[1] * (1 + ratio) / 2
                ),
            ]
            current_image = current_image.squeeze(0)
            resize_transform = transforms.Resize(original_size, antialias=True)
            current_image = resize_transform(current_image)
            current_image = current_image.unsqueeze(0)

        return current_image

    def load_environment(self):
        """
        Create the evaluation environment.
        """
        if self.real_robot:
            self.env = DeviceAloha(self.robot_env)
            self.env_max_reward = 0
        else:
            from sim_env import make_sim_env

            self.env = make_sim_env(self.task_name)
            self.env_max_reward = self.env.task.max_reward

    def get_auto_index(self, checkpoint_dir):
        max_idx = 1000
        for i in range(max_idx + 1):
            if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
                return i
        raise Exception(f"Error getting auto index, or more than {max_idx} episodes")

    def evaluate(self, checkpoint_name=None):
        """
        Roll out the policy and report its performance.

        Returns:
            tuple: success rate and average return
        """
        if checkpoint_name is not None:
            self.checkpoint_name = checkpoint_name
        set_seed(self._seed)  # seed numpy and torch
        self.load_policy_and_stats()
        self.load_environment()

        query_frequency = self.policy_config["num_queries"]

        # With temporal aggregation, query the policy at every timestep
        if self.temporal_agg:
            query_frequency = 1
            num_queries = self.policy_config["num_queries"]

        # # On the real robot, the base delay is 13 steps???
        # if self.real_robot:
        #     BASE_DELAY = 13
        #     # query_frequency -= BASE_DELAY

        max_timesteps = int(self.max_timesteps * 1)  # may increase for real-world tasks
        episode_returns = []
        highest_rewards = []

        for rollout_id in range(self.num_rollouts):

            timestep = self.env.reset()

            if self.onscreen_render:
                # TODO plot the rollout
                pass
            if self.temporal_agg:
                all_time_actions = torch.zeros(
                    [max_timesteps, max_timesteps + num_queries, self.state_dim]
                ).cuda()
            qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
            rewards = []

            with torch.inference_mode():
                time_0 = time.time()
                DT = 1 / 30
                cumulated_delay = 0
                for t in range(max_timesteps):
                    time_1 = time.time()
                    if self.onscreen_render:
                        # TODO display the camera images
                        pass
                    # process previous timestep to get qpos and image_list
                    obs = timestep.observation
                    qpos_numpy = np.array(obs["qpos"])
                    qpos_history_raw[t] = qpos_numpy
                    qpos = self.pre_process(qpos_numpy)
                    qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)

                    logging.info(f"t={t}")

                    if t % query_frequency == 0:
                        current_image = self.get_image_torch(
                            timestep,
                            self.camera_names,
                            random_crop_resize=(
                                self.config["policy_class"] == "Diffusion"
                            ),
                        )

                    if t == 0:
                        # warm up the network
                        for _ in range(10):
                            self.policy(qpos, current_image)
                        logging.info("Network warm up done")

                    if self.config["policy_class"] == "ACT":
                        if t % query_frequency == 0:
                            if self.vq:
                                if rollout_id == 0:
                                    for _ in range(10):
                                        vq_sample = self.latent_model.generate(
                                            1, temperature=1, x=None
                                        )
                                        logging.info(
                                            torch.nonzero(vq_sample[0])[:, 1]
                                            .cpu()
                                            .numpy()
                                        )
                                vq_sample = self.latent_model.generate(
                                    1, temperature=1, x=None
                                )
                                all_actions = self.policy(
                                    qpos, current_image, vq_sample=vq_sample
                                )
                            else:
                                all_actions = self.policy(qpos, current_image)
                            # if self.real_robot:
                            #     all_actions = torch.cat(
                            #         [
                            #             all_actions[:, :-BASE_DELAY, :-2],
                            #             all_actions[:, BASE_DELAY:, -2:],
                            #         ],
                            #         dim=2,
                            #     )
                        if self.temporal_agg:
                            all_time_actions[[t], t : t + num_queries] = all_actions
                            actions_for_curr_step = all_time_actions[:, t]
                            actions_populated = torch.all(
                                actions_for_curr_step != 0, axis=1
                            )
                            actions_for_curr_step = actions_for_curr_step[
                                actions_populated
                            ]
                            k = 0.01
                            exp_weights = np.exp(
                                -k * np.arange(len(actions_for_curr_step))
                            )
                            exp_weights = exp_weights / exp_weights.sum()
                            exp_weights = (
                                torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
                            )
                            raw_action = (actions_for_curr_step * exp_weights).sum(
                                dim=0, keepdim=True
                            )
                        else:
                            raw_action = all_actions[:, t % query_frequency]
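                        # Temporal ensembling (a sketch of the math above): at step t,
                        # every past query that predicted an action for t is averaged
                        # with weights w_i proportional to exp(-k * i), where i = 0 is
                        # the oldest prediction. With k = 0.01 the weighting is mild;
                        # larger k favors older, more committed predictions, and k = 0
                        # reduces to a uniform average.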
                    elif self.config["policy_class"] == "Diffusion":
                        if t % query_frequency == 0:
                            all_actions = self.policy(qpos, current_image)
                            # if self.real_robot:
                            #     all_actions = torch.cat(
                            #         [
                            #             all_actions[:, :-BASE_DELAY, :-2],
                            #             all_actions[:, BASE_DELAY:, -2:],
                            #         ],
                            #         dim=2,
                            #     )
                        raw_action = all_actions[:, t % query_frequency]
                    elif self.config["policy_class"] == "CNNMLP":
                        raw_action = self.policy(qpos, current_image)
                        all_actions = raw_action.unsqueeze(0)
                    else:
                        raise NotImplementedError

                    ### post-process actions
                    raw_action = raw_action.squeeze(0).cpu().numpy()
                    action = self.post_process(raw_action)

                    ### step the environment
                    if self.real_robot:
                        logging.info(f" action = {action}")
                    timestep = self.env.step(action)

                    rewards.append(timestep.reward)
                    duration = time.time() - time_1
                    sleep_time = max(0, DT - duration)
                    time.sleep(sleep_time)
                    if duration >= DT:
                        cumulated_delay += duration - DT
                        logging.warning(
                            f"Warning: step duration {duration:.3f} s at step {t} is longer than DT {DT} s, cumulated delay: {cumulated_delay:.3f} s"
                        )

                logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
                plt.close()

            if self.real_robot:
                log_id = self.get_auto_index(self.checkpoint_dir)
                np.save(
                    os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
                    qpos_history_raw,
                )
                plt.figure(figsize=(10, 20))
                for i in range(self.state_dim):
                    plt.subplot(self.state_dim, 1, i + 1)
                    plt.plot(qpos_history_raw[:, i])
                    if i != self.state_dim - 1:
                        plt.xticks([])
                plt.tight_layout()
                plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
                plt.close()

            rewards = np.array(rewards)
            episode_return = np.sum(rewards[rewards != None])
            episode_returns.append(episode_return)
            episode_highest_reward = np.max(rewards)
            highest_rewards.append(episode_highest_reward)
            logging.info(
                f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
            )

        success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
        avg_return = np.mean(episode_returns)
        summary_str = (
            f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
        )
        for r in range(self.env_max_reward + 1):
            more_or_equal_r = (np.array(highest_rewards) >= r).sum()
            more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
            summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"

        logging.info(summary_str)

        result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
        with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
            f.write(summary_str)
            f.write(repr(episode_returns))
            f.write("\n\n")
            f.write(repr(highest_rewards))

        return success_rate, avg_return


class DeviceAloha:
    def __init__(self, aloha_config):
        """
        Initialize the bimanual device.

        Args:
            aloha_config (dict): robot environment configuration
        """
        config_left_arm = aloha_config["rm_left_arm"]
        config_right_arm = aloha_config["rm_right_arm"]
        config_head_camera = aloha_config["head_camera"]
        config_bottom_camera = aloha_config["bottom_camera"]
        config_left_camera = aloha_config["left_camera"]
        config_right_camera = aloha_config["right_camera"]
        self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
        self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
        self.arm_axis = aloha_config["arm_axis"]
        self.arm_left = RmArm(config_left_arm)
        self.arm_right = RmArm(config_right_arm)
        self.camera_top = RealSenseCamera(config_head_camera, False)
        self.camera_bottom = RealSenseCamera(config_bottom_camera, False)
        self.camera_left = RealSenseCamera(config_left_camera, False)
        self.camera_right = RealSenseCamera(config_right_camera, False)
        self.camera_left.start_camera()
        self.camera_right.start_camera()
        self.camera_bottom.start_camera()
        self.camera_top.start_camera()

    def close(self):
        """
        Close all cameras.
        """
        self.camera_left.close()
        self.camera_right.close()
        self.camera_bottom.close()
        self.camera_top.close()

    def get_qpos(self):
        """
        Get the joint angles of both arms.

        Returns:
            np.array: joint angles, left arm first
        """
        left_slave_arm_angle = self.arm_left.get_joint_angle()
        left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
        right_slave_arm_angle = self.arm_right.get_joint_angle()
        right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
        return np.concatenate([left_joint_angles_array, right_joint_angles_array])

    def get_qvel(self):
        """
        Get the joint velocities of both arms.

        Returns:
            np.array: joint velocities, left arm first
        """
        left_slave_arm_velocity = self.arm_left.get_joint_velocity()
        left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
        right_slave_arm_velocity = self.arm_right.get_joint_velocity()
        right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
        return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])

    def get_effort(self):
        """
        Get the joint efforts of both arms.

        Returns:
            np.array: joint efforts, left arm first
        """
        left_slave_arm_effort = self.arm_left.get_joint_effort()
        left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
        right_slave_arm_effort = self.arm_right.get_joint_effort()
        right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
        return np.concatenate([left_joint_effort_array, right_joint_effort_array])

    def get_images(self):
        """
        Grab one color frame from each camera.

        Returns:
            dict: camera name -> image
        """
        self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
        self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
            True, False, False, False
        )
        self.left_image, _, _, _ = self.camera_left.read_frame(
            True, False, False, False
        )
        self.right_image, _, _, _ = self.camera_right.read_frame(
            True, False, False, False
        )
        return {
            "cam_high": self.top_image,
            "cam_low": self.bottom_image,
            "cam_left": self.left_image,
            "cam_right": self.right_image,
        }

    def get_observation(self):
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def reset(self):
        logging.info("Resetting the environment")
        self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
        self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
        self.arm_left.set_gripper_position(0)
        self.arm_right.set_gripper_position(0)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )

    def step(self, target_angle):
        # target_angle layout: [left joints (arm_axis), left gripper,
        #                       right joints (arm_axis), right gripper]
        self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
        self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis + 1:self.arm_axis * 2 + 1])
        self.arm_left.set_gripper_position(target_angle[self.arm_axis])
        self.arm_right.set_gripper_position(target_angle[self.arm_axis * 2 + 1])
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )


if __name__ == "__main__":
    # with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
    #     config = yaml.safe_load(f)
    # aloha_config = config["robot_env"]
    # device = DeviceAloha(aloha_config)
    # device.reset()
    # while True:
    #     init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
    #     time_step = time.time()
    #     timestep = device.step(init_angle)
    #     logging.info(f"Time: {time.time() - time_step}")
    #     obs = timestep.observation

    with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
        config = yaml.safe_load(f)
    # logging.info(f"Config: {config}")
    evaluator = RmActEvaluator(config)
    success_rate, avg_return = evaluator.evaluate()
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
import torch
import torchvision
from torch import nn
from typing import Dict, List
import torch.nn.functional as F
from .position_encoding import build_position_encoding
from torchvision.models import ResNet18_Weights
from torchvision.models._utils import IntermediateLayerGetter
from shadow_act.utils.misc import NestedTensor, is_main_process


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any policy_models other than torchvision.policy_models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
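    # The two lines above fold the usual batch-norm formula
    #   y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
    # into a single fused multiply-add, y = x * scale + bias', which is
    # equivalent here because all four buffers are frozen constants.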


class BackboneBase(nn.Module):

    def __init__(
        self,
        backbone: nn.Module,
        train_backbone: bool,
        num_channels: int,
        return_interm_layers: bool,
    ):
        super().__init__()
        # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
        #     if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
        #         parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {"layer4": "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor):
        xs = self.body(tensor)
        return xs
        # out: Dict[str, NestedTensor] = {}
        # for name, x in xs.items():
        #     m = tensor_list.mask
        #     assert m is not None
        #     mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        #     out[name] = NestedTensor(x, mask)
        # return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        return_interm_layers: bool,
        dilation: bool,
    ):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None,
            norm_layer=FrozenBatchNorm2d,
        )
        # backbone = getattr(torchvision.models, name)(
        #     replace_stride_with_dilation=[False, False, dilation],
        #     pretrained=is_main_process(),
        #     norm_layer=FrozenBatchNorm2d,
        # ) # pretrained # TODO do we want frozen batch_norm??
        num_channels = 512 if name in ("resnet18", "resnet34") else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.dtype))

        return out, pos


def build_backbone(
    hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
):

    position_embedding = build_position_encoding(
        hidden_dim=hidden_dim, position_embedding_type=position_embedding_type
    )
    train_backbone = lr_backbone > 0
    return_interm_layers = masks
    backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
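# A hypothetical call, with argument values assumed for illustration:
#   model = build_backbone(hidden_dim=512, position_embedding_type="sine",
#                          lr_backbone=1e-5, masks=False,
#                          backbone="resnet18", dilation=False)
#   features, pos = model(images)  # lists with one entry per returned layer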
@@ -0,0 +1,436 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
import torch.nn.functional as F
from shadow_act.models.transformer import Transformer
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np


def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = torch.randn_like(std)
    return mu + std * eps
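# Reparameterization trick: a sample z ~ N(mu, sigma^2) is rewritten as
# z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(logvar / 2),
# so gradients can flow through mu and logvar while the noise stays external.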


def get_sinusoid_encoding_table(n_position, d_hid):
    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
    )
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
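# Shape sketch: get_sinusoid_encoding_table(12, 256) returns a (1, 12, 256)
# float tensor; even columns carry sin terms and odd columns cos terms of the
# same frequency, matching the fixed encoding from "Attention Is All You Need".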


class DETRVAE(nn.Module):
    """This is the DETR module that performs object detection"""

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        vq,
        vq_class,
        vq_dim,
        action_dim,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
        self.state_dim, self.action_dim = state_dim, action_dim
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            action_dim, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(
            action_dim, hidden_dim
        )  # project qpos to embedding
        if self.vq:
            self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
        else:
            self.latent_proj = nn.Linear(
                hidden_dim, self.latent_dim * 2
            )  # project hidden state to latent mean and log-variance
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        if self.vq:
            self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
        else:
            self.latent_out_proj = nn.Linear(
                self.latent_dim, hidden_dim
            )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

    def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
        bs, _ = qpos.shape
        if self.encoder is None:
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)
            probs = binaries = mu = logvar = None
        else:
            # cvae encoder
            is_training = actions is not None  # train or val
            ### Obtain latent z from action sequence
            if is_training:
                # project action sequence to embedding dim, and concat with a CLS token
                action_embed = self.encoder_action_proj(
                    actions
                )  # (bs, seq, hidden_dim)
                qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
                qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
                cls_embed = self.cls_embed.weight  # (1, hidden_dim)
                cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                    bs, 1, 1
                )  # (bs, 1, hidden_dim)
                encoder_input = torch.cat(
                    [cls_embed, qpos_embed, action_embed], axis=1
                )  # (bs, seq+2, hidden_dim)
                encoder_input = encoder_input.permute(
                    1, 0, 2
                )  # (seq+2, bs, hidden_dim)
                # do not mask the cls token or qpos
                cls_joint_is_pad = torch.full((bs, 2), False).to(
                    qpos.device
                )  # False: not a padding
                is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+2)
                # obtain position embedding
                pos_embed = self.pos_table.clone().detach()
                pos_embed = pos_embed.permute(1, 0, 2)  # (seq+2, 1, hidden_dim)
                # query model
                encoder_output = self.encoder(
                    encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
                )
                encoder_output = encoder_output[0]  # take cls output only
                latent_info = self.latent_proj(encoder_output)

                if self.vq:
                    logits = latent_info.reshape(
                        [*latent_info.shape[:-1], self.vq_class, self.vq_dim]
                    )
                    probs = torch.softmax(logits, dim=-1)
                    binaries = (
                        F.one_hot(
                            torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(
                                -1
                            ),
                            self.vq_dim,
                        )
                        .view(-1, self.vq_class, self.vq_dim)
                        .float()
                    )
                    binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
                    probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
                    straight_through = binaries_flat - probs_flat.detach() + probs_flat
                    latent_input = self.latent_out_proj(straight_through)
                    mu = logvar = None
                else:
                    probs = binaries = None
                    mu = latent_info[:, : self.latent_dim]
                    logvar = latent_info[:, self.latent_dim :]
                    latent_sample = reparametrize(mu, logvar)
                    latent_input = self.latent_out_proj(latent_sample)

            else:
                mu = logvar = binaries = probs = None
                if self.vq:
                    latent_input = self.latent_out_proj(
                        vq_sample.view(-1, self.vq_class * self.vq_dim)
                    )
                else:
                    latent_sample = torch.zeros(
                        [bs, self.latent_dim], dtype=torch.float32
                    ).to(qpos.device)
                    latent_input = self.latent_out_proj(latent_sample)

        return latent_input, probs, binaries, mu, logvar

    def forward(
        self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None
    ):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """

        latent_input, probs, binaries, mu, logvar = self.encode(
            qpos, actions, is_pad, vq_sample
        )

        # cvae decoder
        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                # TODO: fix this error
                features, pos = self.backbones[0](image[:, cam_id])
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(
                src,
                None,
                self.query_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
            )[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(
                transformer_input, None, self.query_embed.weight, self.pos.weight
            )[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar], probs, binaries


class CNNMLP(nn.Module):
    def __init__(self, backbones, state_dim, camera_names):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim)  # TODO add more
        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
            backbone_down_projs = []
            for backbone in backbones:
                down_proj = nn.Sequential(
                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                    nn.Conv2d(128, 64, kernel_size=5),
                    nn.Conv2d(64, 32, kernel_size=5),
                )
                backbone_down_projs.append(down_proj)
            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

            mlp_in_dim = 768 * len(backbones) + state_dim
            self.mlp = mlp(
                input_dim=mlp_in_dim,
                hidden_dim=1024,
                output_dim=state_dim,
                hidden_depth=2,
            )
        else:
            raise NotImplementedError

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        is_training = actions is not None  # train or val
        bs, _ = qpos.shape
        # Image observation features and position embeddings
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0]  # take the last layer feature
            pos = pos[0]  # not used
            all_cam_features.append(self.backbone_down_projs[cam_id](features))
        # flatten everything
        flattened_features = []
        for cam_feature in all_cam_features:
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, axis=1)  # 768 each
        features = torch.cat([flattened_features, qpos], axis=1)  # qpos: 14
        a_hat = self.mlp(features)
        return a_hat


def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    trunk = nn.Sequential(*mods)
    return trunk
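# Hypothetical instantiation matching the dimensions used above (a sketch):
#   trunk = mlp(input_dim=768 * num_cams + 14, hidden_dim=1024,
#               output_dim=14, hidden_depth=2)
# hidden_depth counts the hidden Linear+ReLU pairs between input and output.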


def build_encoder(
    hidden_dim,  # 256
    dropout,  # 0.1
    nheads,  # 8
    dim_feedforward,
    num_encoder_layers,  # 4 # TODO shared with VAE decoder
    normalize_before,  # False
):
    activation = "relu"

    encoder_layer = TransformerEncoderLayer(
        hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
    )
    encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

    return encoder


def build_vae(
    hidden_dim,
    state_dim,
    position_embedding_type,
    lr_backbone,
    masks,
    backbone,
    dilation,
    dropout,
    nheads,
    dim_feedforward,
    enc_layers,
    dec_layers,
    pre_norm,
    num_queries,
    camera_names,
    vq,
    vq_class,
    vq_dim,
    action_dim,
    no_encoder,
):
    # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    backbone = build_backbone(
        hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
    )
    backbones.append(backbone)

    transformer = build_transformer(
        hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
    )

    if no_encoder:
        encoder = None
    else:
        encoder = build_encoder(
            hidden_dim,
            dropout,
            nheads,
            dim_feedforward,
            enc_layers,
            pre_norm,
        )

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        vq,
        vq_class,
        vq_dim,
        action_dim,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters / 1e6,))

    return model
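# A sketch of how this factory is typically driven from a policy config dict;
# all values below are assumed for illustration only:
#   model = build_vae(hidden_dim=512, state_dim=14, position_embedding_type="sine",
#                     lr_backbone=1e-5, masks=False, backbone="resnet18",
#                     dilation=False, dropout=0.1, nheads=8, dim_feedforward=3200,
#                     enc_layers=4, dec_layers=7, pre_norm=False, num_queries=100,
#                     camera_names=["cam_high"], vq=False, vq_class=None,
#                     vq_dim=None, action_dim=14, no_encoder=False)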


# TODO
def build_cnnmlp(args):
    state_dim = 14  # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    for _ in args.camera_names:
        backbone = build_backbone(args)
        backbones.append(backbone)

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters / 1e6,))

    return model
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
import torch.nn as nn
from torch.nn import functional as F
import torch

DROPOUT_RATE = 0.1  # dropout rate


# A transformer block with a causal (autoregressive) attention mask
class Causal_Transformer_Block(nn.Module):
    def __init__(self, seq_len, latent_dim, num_head) -> None:
        """
        Initialize the causal transformer block.

        Args:
            seq_len (int): sequence length
            latent_dim (int): latent dimension
            num_head (int): number of attention heads
        """
        super().__init__()
        self.num_head = num_head
        self.latent_dim = latent_dim
        self.ln_1 = nn.LayerNorm(latent_dim)  # layer norm before attention
        self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True)  # multi-head self-attention
        self.ln_2 = nn.LayerNorm(latent_dim)  # layer norm before the MLP
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim, 4 * latent_dim),  # expand
            nn.GELU(),  # GELU activation
            nn.Linear(4 * latent_dim, latent_dim),  # project back
            nn.Dropout(DROPOUT_RATE),  # dropout
        )

        # self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool())  # precomputed attention mask

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        # Upper-triangular mask so no position can attend to the future
        attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
        x = self.ln_1(x)  # layer norm
        x = x + self.attn(x, x, x, attn_mask=attn_mask)[0]  # residual attention
        x = self.ln_2(x)  # layer norm
        x = x + self.mlp(x)  # residual MLP

        return x
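        # Mask semantics, for reference: nn.MultiheadAttention treats True
        # entries of a boolean attn_mask as "not allowed", so the
        # triu(..., diagonal=1) mask above blocks attention from position i
        # to any j > i while leaving j <= i visible.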


# Model the latent-code sequence with self-attention instead of an RNN
class Latent_Model_Transformer(nn.Module):
    def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
        """
        Initialize the latent model transformer.

        Args:
            input_dim (int): input dimension
            output_dim (int): output dimension
            seq_len (int): sequence length
            latent_dim (int, optional): latent dimension, defaults to 256
            num_head (int, optional): number of attention heads, defaults to 8
            num_layer (int, optional): number of transformer layers, defaults to 3
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        self.num_head = num_head
        self.num_layer = num_layer
        self.input_layer = nn.Linear(input_dim, latent_dim)  # input projection
        self.weight_pos_embed = nn.Embedding(seq_len, latent_dim)  # learned positional embedding
        self.attention_blocks = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),  # dropout
            *[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)],  # stacked causal blocks
            nn.LayerNorm(latent_dim)  # final layer norm
        )
        self.output_layer = nn.Linear(latent_dim, output_dim)  # output projection

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output logits
        """
        x = self.input_layer(x)  # input projection
        x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device))  # add positional embedding
        x = self.attention_blocks(x)  # causal attention stack
        logits = self.output_layer(x)  # output projection

        return logits

    @torch.no_grad()
    def generate(self, n, temperature=0.1, x=None):
        """
        Autoregressively sample latent-code sequences.

        Args:
            n (int): number of sequences to generate
            temperature (float, optional): sampling temperature, defaults to 0.1
            x (torch.Tensor, optional): initial input tensor, defaults to None

        Returns:
            torch.Tensor: the generated sequences
        """
        if x is None:
            x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device)  # start from a zero token
        for i in range(self.seq_len):
            logits = self.forward(x)[:, -1]  # logits at the last position
            probs = torch.softmax(logits / temperature, dim=-1)  # temperature-scaled distribution
            samples = torch.multinomial(probs, num_samples=1)[..., 0]  # sample a code index
            samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float()  # one-hot encode
            x = torch.cat([x, samples_one_hot[:, None, :]], dim=1)  # append to the sequence

        return x[:, 1:, :]  # return the generated sequence (drop the initial zero token)
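    # Example draw (shapes assumed for illustration): for a model built as
    #   lm = Latent_Model_Transformer(input_dim=32, output_dim=32, seq_len=8)
    # lm.generate(n=1, temperature=1) returns a (1, 8, 32) one-hot tensor, one
    # sampled code per position; higher temperature flattens the distribution.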
@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from shadow_act.utils.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        # mask = tensor_list.mask
        # assert mask is not None
        # not_mask = ~mask

        not_mask = torch.ones_like(x[0, [0]])
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(hidden_dim, position_embedding_type):
    N_steps = hidden_dim // 2
    if position_embedding_type in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif position_embedding_type in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {position_embedding_type}")

    return position_embedding
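# Hypothetical usage: for hidden_dim=512 this builds a sine encoding with
# N_steps=256 features per spatial axis; calling it on a (B, C, H, W) feature
# map returns a (1, 2*N_steps, H, W) positional tensor (y-features first, then
# x-features) that broadcasts over the batch.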
@@ -0,0 +1,424 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class Transformer(nn.Module):

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
    ):
        # TODO flatten only when input has H and W
        if len(src.shape) == 4:  # has H and W
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            # mask = mask.flatten(1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
            # flatten NxHWxC to HWxNxC
            bs, hw, c = src.shape
            src = src.permute(1, 0, 2)
            pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt,
            memory,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        hs = hs.transpose(1, 2)
        return hs


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(
    hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
):
    return Transformer(
        d_model=hidden_dim,
        dropout=dropout,
        nhead=nheads,
        dim_feedforward=dim_feedforward,
        num_encoder_layers=enc_layers,
        num_decoder_layers=dec_layers,
        normalize_before=pre_norm,
        return_intermediate_dec=True,
    )
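# A sketch of a typical call, argument values assumed for illustration:
#   transformer = build_transformer(hidden_dim=512, dropout=0.1, nheads=8,
#                                   dim_feedforward=3200, enc_layers=4,
#                                   dec_layers=7, pre_norm=False)
# return_intermediate_dec=True makes the decoder return a stack of per-layer
# activations; DETRVAE.forward above selects one slice of that stack with [0].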


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,522 @@
import torch
import logging
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from shadow_act.models.detr_vae import build_vae, build_cnnmlp

from diffusers.training_utils import EMAModel
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
from robomimic.algo.diffusion_policy import replace_bn_with_gn, ConditionalUnet1D

# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# TODO: refactor the DiffusionPolicy class
class DiffusionPolicy(nn.Module):
    def __init__(self, args_override):
        """
        Initialize the DiffusionPolicy.

        Args:
            args_override (dict): dictionary of parameter overrides
        """
        super().__init__()

        self.camera_names = args_override["camera_names"]
        self.observation_horizon = args_override["observation_horizon"]
        self.action_horizon = args_override["action_horizon"]
        self.prediction_horizon = args_override["prediction_horizon"]
        self.num_inference_timesteps = args_override["num_inference_timesteps"]
        self.ema_power = args_override["ema_power"]
        self.lr = args_override["lr"]
        self.weight_decay = 0

        self.num_kp = 32
        self.feature_dimension = 64
        self.ac_dim = args_override["action_dim"]
        self.obs_dim = self.feature_dimension * len(self.camera_names) + 14  # camera features + qpos

        backbones = []
        pools = []
        linears = []
        for _ in self.camera_names:
            backbones.append(
                ResNet18Conv(input_channel=3, pretrained=False, input_coord_conv=False)
            )
            pools.append(
                SpatialSoftmax(
                    input_shape=[512, 15, 20],
                    num_kp=self.num_kp,
                    temperature=1.0,
                    learnable_temperature=False,
                    noise_std=0.0,
                )
            )
            linears.append(
                torch.nn.Linear(int(np.prod([self.num_kp, 2])), self.feature_dimension)
            )
        backbones = nn.ModuleList(backbones)
        pools = nn.ModuleList(pools)
        linears = nn.ModuleList(linears)

        backbones = replace_bn_with_gn(backbones)

        noise_pred_net = ConditionalUnet1D(
            input_dim=self.ac_dim,
            global_cond_dim=self.obs_dim * self.observation_horizon,
        )

        nets = nn.ModuleDict(
            {
                "policy": nn.ModuleDict(
                    {
                        "backbones": backbones,
                        "pools": pools,
                        "linears": linears,
                        "noise_pred_net": noise_pred_net,
                    }
                )
            }
        )

        nets = nets.float().cuda()
        ENABLE_EMA = True
        if ENABLE_EMA:
            ema = EMAModel(model=nets, power=self.ema_power)
        else:
            ema = None
        self.nets = nets
        self.ema = ema

        # Set up the noise scheduler
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=50,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            set_alpha_to_one=True,
            steps_offset=0,
            prediction_type="epsilon",
        )
|
||||
n_parameters = sum(p.numel() for p in self.parameters())
|
||||
logger.info("number of parameters: %.2fM", n_parameters / 1e6)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
配置优化器
|
||||
|
||||
Returns:
|
||||
optimizer: 配置的优化器
|
||||
"""
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.nets.parameters(), lr=self.lr, weight_decay=self.weight_decay
|
||||
)
|
||||
return optimizer
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""
|
||||
前向传播函数
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 位置张量
|
||||
image (torch.Tensor): 图像张量
|
||||
actions (torch.Tensor, optional): 动作张量
|
||||
is_pad (torch.Tensor, optional): 填充张量
|
||||
|
||||
Returns:
|
||||
dict: 损失字典(训练时)
|
||||
torch.Tensor: 动作张量(推理时)
|
||||
"""
|
||||
B = qpos.shape[0]
|
||||
if actions is not None: # 训练时
|
||||
nets = self.nets
|
||||
all_features = []
|
||||
for cam_id in range(len(self.camera_names)):
|
||||
cam_image = image[:, cam_id]
|
||||
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||
all_features.append(out_features)
|
||||
|
||||
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||
|
||||
# 为动作添加噪声
|
||||
noise = torch.randn(actions.shape, device=obs_cond.device)
|
||||
|
||||
# 为每个数据点采样一个扩散迭代
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
self.noise_scheduler.config.num_train_timesteps,
|
||||
(B,),
|
||||
device=obs_cond.device,
|
||||
).long()
|
||||
|
||||
# 根据每个扩散迭代的噪声幅度向干净动作添加噪声
|
||||
noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)
|
||||
|
||||
# 预测噪声残差
|
||||
noise_pred = nets["policy"]["noise_pred_net"](
|
||||
noisy_actions, timesteps, global_cond=obs_cond
|
||||
)
|
||||
|
||||
# L2损失
|
||||
all_l2 = F.mse_loss(noise_pred, noise, reduction="none")
|
||||
loss = (all_l2 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l2_loss"] = loss
|
||||
loss_dict["loss"] = loss
|
||||
|
||||
if self.training and self.ema is not None:
|
||||
self.ema.step(nets)
|
||||
return loss_dict
|
||||
else: # 推理时
|
||||
To = self.observation_horizon
|
||||
Ta = self.action_horizon
|
||||
Tp = self.prediction_horizon
|
||||
action_dim = self.ac_dim
|
||||
|
||||
nets = self.nets
|
||||
if self.ema is not None:
|
||||
nets = self.ema.averaged_model
|
||||
|
||||
all_features = []
|
||||
for cam_id in range(len(self.camera_names)):
|
||||
cam_image = image[:, cam_id]
|
||||
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||
all_features.append(out_features)
|
||||
|
||||
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||
|
||||
# 从高斯噪声初始化动作
|
||||
noisy_action = torch.randn((B, Tp, action_dim), device=obs_cond.device)
|
||||
naction = noisy_action
|
||||
|
||||
# 初始化调度器
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||
|
||||
for k in self.noise_scheduler.timesteps:
|
||||
# 预测噪声
|
||||
noise_pred = nets["policy"]["noise_pred_net"](
|
||||
sample=naction, timestep=k, global_cond=obs_cond
|
||||
)
|
||||
|
||||
# 逆扩散步骤(去除噪声)
|
||||
naction = self.noise_scheduler.step(
|
||||
model_output=noise_pred, timestep=k, sample=naction
|
||||
).prev_sample
|
||||
|
||||
return naction
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
序列化模型
|
||||
|
||||
Returns:
|
||||
dict: 模型状态字典
|
||||
"""
|
||||
return {
|
||||
"nets": self.nets.state_dict(),
|
||||
"ema": (
|
||||
self.ema.averaged_model.state_dict() if self.ema is not None else None
|
||||
),
|
||||
}
|
||||
|
||||
def deserialize(self, model_dict):
|
||||
"""
|
||||
反序列化模型
|
||||
|
||||
Args:
|
||||
model_dict (dict): 模型状态字典
|
||||
|
||||
Returns:
|
||||
status: 加载状态
|
||||
"""
|
||||
status = self.nets.load_state_dict(model_dict["nets"])
|
||||
logger.info("Loaded model")
|
||||
if model_dict.get("ema", None) is not None:
|
||||
logger.info("Loaded EMA")
|
||||
status_ema = self.ema.averaged_model.load_state_dict(model_dict["ema"])
|
||||
status = [status, status_ema]
|
||||
return status
|
||||
|
||||
|
||||
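# Illustrative sketch (not in the original file): querying DiffusionPolicy in a
# receding-horizon loop. `policy` is a constructed DiffusionPolicy; `get_obs()`
# and `step_env()` are hypothetical helpers, and shapes are assumptions.
#
#   policy.eval()
#   with torch.inference_mode():
#       qpos, image = get_obs()                 # (1, 14), (1, num_cams, 3, H, W)
#       actions = policy(qpos, image)           # (1, prediction_horizon, action_dim)
#       for t in range(policy.action_horizon):  # execute Ta steps, then re-query
#           step_env(actions[0, t])

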
class ACTPolicy(nn.Module):
    def __init__(self, act_config):
        """
        Initialize the ACTPolicy class.

        Args:
            act_config (dict): configuration dictionary
        """
        super().__init__()
        lr_backbone = act_config["lr_backbone"]
        vq = act_config["vq"]
        lr = act_config["lr"]
        weight_decay = act_config["weight_decay"]

        model = build_vae(
            act_config["hidden_dim"],
            act_config["state_dim"],
            act_config["position_embedding"],
            lr_backbone,
            act_config["masks"],
            act_config["backbone"],
            act_config["dilation"],
            act_config["dropout"],
            act_config["nheads"],
            act_config["dim_feedforward"],
            act_config["enc_layers"],
            act_config["dec_layers"],
            act_config["pre_norm"],
            act_config["num_queries"],
            act_config["camera_names"],
            vq,
            act_config["vq_class"],
            act_config["vq_dim"],
            act_config["action_dim"],
            act_config["no_encoder"],
        )
        model.cuda()

        param_dicts = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" not in n and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" in n and p.requires_grad
                ],
                "lr": lr_backbone,
            },
        ]
        self.optimizer = torch.optim.AdamW(
            param_dicts, lr=lr, weight_decay=weight_decay
        )
        self.model = model  # CVAE decoder
        self.kl_weight = act_config["kl_weight"]
        self.vq = vq
        logger.info(f"KL Weight {self.kl_weight}")

    def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None):
        """
        Forward pass.

        Args:
            qpos (torch.Tensor): joint-angle tensor
            image (torch.Tensor): image tensor
            actions (torch.Tensor, optional): action tensor
            is_pad (torch.Tensor, optional): padding-mask tensor
            vq_sample (torch.Tensor, optional): VQ sample

        Returns:
            dict: loss dictionary (during training)
            torch.Tensor: action tensor (during inference)
        """
        env_state = None
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        image = normalize(image)
        if actions is not None:  # training
            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]

            loss_dict = dict()
            a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model(
                qpos, image, env_state, actions, is_pad
            )
            if self.vq or self.model.encoder is None:
                total_kld = [torch.tensor(0.0)]
            else:
                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            if self.vq:
                loss_dict["vq_discrepancy"] = F.l1_loss(
                    probs, binaries, reduction="mean"
                )
            all_l1 = F.l1_loss(actions, a_hat, reduction="none")
            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
            loss_dict["l1"] = l1
            loss_dict["kl"] = total_kld[0]
            loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
            return loss_dict
        else:  # inference
            a_hat, _, (_, _), _, _ = self.model(
                qpos, image, env_state, vq_sample=vq_sample
            )  # no action, sample from the prior
            return a_hat

    def configure_optimizers(self):
        """
        Configure the optimizer.

        Returns:
            optimizer: the configured optimizer
        """
        return self.optimizer

    @torch.no_grad()
    def vq_encode(self, qpos, actions, is_pad):
        """
        VQ encoding.

        Args:
            qpos (torch.Tensor): joint-position tensor
            actions (torch.Tensor): action tensor
            is_pad (torch.Tensor): padding-mask tensor

        Returns:
            torch.Tensor: binary codes
        """
        actions = actions[:, : self.model.num_queries]
        is_pad = is_pad[:, : self.model.num_queries]

        _, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad)

        return binaries

    def serialize(self):
        """
        Serialize the model.

        Returns:
            dict: model state dict
        """
        return self.state_dict()

    def deserialize(self, model_dict):
        """
        Deserialize the model.

        Args:
            model_dict (dict): model state dict

        Returns:
            status: loading status
        """
        return self.load_state_dict(model_dict)


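# Illustrative sketch (not in the original file): a minimal act_config for
# ACTPolicy, mirroring the README's training command (kl_weight=10,
# chunk_size=100 -> num_queries, hidden_dim=512, dim_feedforward=3200). The
# remaining values are assumptions chosen for illustration only.
if __name__ == "__main__":
    example_act_config = {
        "lr": 1e-5,
        "lr_backbone": 1e-5,
        "weight_decay": 1e-4,
        "kl_weight": 10,
        "hidden_dim": 512,
        "dim_feedforward": 3200,
        "num_queries": 100,
        "state_dim": 14,
        "action_dim": 14,
        "position_embedding": "sine",
        "masks": False,
        "backbone": "resnet18",
        "dilation": False,
        "dropout": 0.1,
        "nheads": 8,
        "enc_layers": 4,
        "dec_layers": 7,
        "pre_norm": False,
        "camera_names": ["top"],
        "vq": False,
        "vq_class": None,
        "vq_dim": None,
        "no_encoder": False,
    }
    # policy = ACTPolicy(example_act_config)  # requires a CUDA device

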
class CNNMLPPolicy(nn.Module):
    def __init__(self, args_override):
        """
        Initialize the CNNMLPPolicy class.

        Args:
            args_override (dict): dictionary of parameter overrides
        """
        super().__init__()
        # parser = argparse.ArgumentParser(
        #     "DETR training and evaluation script", parents=[get_args_parser()]
        # )
        # args = parser.parse_args()

        # for k, v in args_override.items():
        #     setattr(args, k, v)

        model = build_cnnmlp(args_override)
        model.cuda()

        param_dicts = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" not in n and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if "backbone" in n and p.requires_grad
                ],
                # args_override is a dict, so index it rather than using
                # attribute access
                "lr": args_override["lr_backbone"],
            },
        ]
        self.model = model  # decoder
        self.optimizer = torch.optim.AdamW(
            param_dicts,
            lr=args_override["lr"],
            weight_decay=args_override["weight_decay"],
        )

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """
        Forward pass.

        Args:
            qpos (torch.Tensor): joint-position tensor
            image (torch.Tensor): image tensor
            actions (torch.Tensor, optional): action tensor
            is_pad (torch.Tensor, optional): padding-mask tensor

        Returns:
            dict: loss dictionary (during training)
            torch.Tensor: action tensor (during inference)
        """
        env_state = None
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        image = normalize(image)
        if actions is not None:  # training
            actions = actions[:, 0]
            a_hat = self.model(qpos, image, env_state, actions)
            mse = F.mse_loss(actions, a_hat)
            loss_dict = dict()
            loss_dict["mse"] = mse
            loss_dict["loss"] = loss_dict["mse"]
            return loss_dict
        else:  # inference
            a_hat = self.model(qpos, image, env_state)  # no action, sample from the prior
            return a_hat

    def configure_optimizers(self):
        """
        Configure the optimizer.

        Returns:
            optimizer: the configured optimizer
        """
        return self.optimizer


def kl_divergence(mu, logvar):
    """
    Compute the KL divergence between N(mu, exp(logvar)) and N(0, I).

    Args:
        mu (torch.Tensor): mean tensor
        logvar (torch.Tensor): log-variance tensor

    Returns:
        tuple: total KL, dimension-wise KL, mean KL
    """
    batch_size = mu.size(0)
    assert batch_size != 0
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld
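

# Sanity check (illustrative, not in the original file): for mu = 1 and
# logvar = 0 the closed form -0.5 * (1 + logvar - mu^2 - exp(logvar)) gives
# 0.5 per dimension, so the summed KL over 8 dimensions is 4.0.
if __name__ == "__main__":
    _mu = torch.ones(4, 8)
    _logvar = torch.zeros(4, 8)
    _total, _dim_wise, _mean = kl_divergence(_mu, _logvar)
    assert torch.allclose(_total, torch.tensor([4.0]))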
@@ -0,0 +1,245 @@
import os
import yaml
import pickle
import torch
# import wandb
import logging
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from itertools import repeat
from shadow_act.utils.utils import (
    set_seed,
    load_data,
    compute_dict_mean,
)
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.eval.rm_act_eval import RmActEvaluator

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class RmActTrainer:
    def __init__(self, config):
        """
        Initialize the trainer: set the random seed, load the data, and save dataset statistics.
        """
        self._config = config
        self._num_steps = config["num_steps"]
        self._ckpt_dir = config["checkpoint_dir"]
        self._state_dim = config["state_dim"]
        self._real_robot = config["real_robot"]
        self._policy_class = config["policy_class"]
        self._onscreen_render = config["onscreen_render"]
        self._policy_config = config["policy_config"]
        self._camera_names = config["camera_names"]
        self._max_timesteps = config["episode_len"]
        self._task_name = config["task_name"]
        self._temporal_agg = config["temporal_agg"]
        self._onscreen_cam = "angle"
        self._vq = config["policy_config"]["vq"]
        self._batch_size = config["batch_size"]

        self._seed = config["seed"]
        self._eval_every = config["eval_every"]
        self._validate_every = config["validate_every"]
        self._save_every = config["save_every"]
        self._load_pretrain = config["load_pretrain"]
        self._resume_ckpt_path = config["resume_ckpt_path"]

        if config["name_filter"] is None:
            name_filter = lambda n: True
        else:
            name_filter = config["name_filter"]

        self._eval = RmActEvaluator(config, True, 50)
        # Load the data
        self._train_dataloader, self._val_dataloader, self._stats, _ = load_data(
            config["dataset_dir"],
            name_filter,
            self._camera_names,
            self._batch_size,
            self._batch_size,
            config["chunk_size"],
            config["skip_mirrored_data"],
            self._load_pretrain,
            self._policy_class,
            config["stats_dir"],
            config["sample_weights"],
            config["train_ratio"],
        )
        # Save dataset statistics
        stats_path = os.path.join(self._ckpt_dir, "dataset_stats.pkl")
        with open(stats_path, "wb") as f:
            pickle.dump(self._stats, f)
        expr_name = self._ckpt_dir.split("/")[-1]

        # wandb.init(
        #     project="train_rm_aloha", reinit=True, entity="train_rm_aloha", name=expr_name
        # )

    def _make_policy(self):
        """
        Create the policy object from the policy class and config.
        """
        if self._policy_class == "ACT":
            return ACTPolicy(self._policy_config)
        elif self._policy_class == "CNNMLP":
            return CNNMLPPolicy(self._policy_config)
        elif self._policy_class == "Diffusion":
            return DiffusionPolicy(self._policy_config)
        else:
            raise NotImplementedError(f"Policy class {self._policy_class} is not implemented")

    def _make_optimizer(self):
        """
        Create the optimizer for the policy class.
        """
        if self._policy_class in ["ACT", "CNNMLP", "Diffusion"]:
            return self._policy.configure_optimizers()
        else:
            raise NotImplementedError

    def _forward_pass(self, data):
        """
        Forward pass; computes the losses.
        """
        image_data, qpos_data, action_data, is_pad = data
        try:
            image_data, qpos_data, action_data, is_pad = (
                image_data.cuda(),
                qpos_data.cuda(),
                action_data.cuda(),
                is_pad.cuda(),
            )
        except RuntimeError as e:
            logging.error(f"CUDA error: {e}")
            raise
        return self._policy(qpos_data, image_data, action_data, is_pad)

    def _repeater(self):
        """
        Infinite repeater over the train dataloader; yields batches forever.
        """
        epoch = 0
        for loader in repeat(self._train_dataloader):
            for data in loader:
                yield data
            logging.info(f"Epoch {epoch} done")
            epoch += 1

    def train(self):
        """
        Train the model and save the best checkpoint.
        """
        set_seed(self._seed)
        self._policy = self._make_policy()
        min_val_loss = np.inf
        best_ckpt_info = None

        # Load a pretrained model
        if self._load_pretrain:
            try:
                loading_status = self._policy.deserialize(
                    torch.load(
                        os.path.join(
                            "/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all",
                            "policy_step_50000_seed_0.ckpt",
                        )
                    )
                )
                logging.info(f"loaded! {loading_status}")
            except FileNotFoundError as e:
                logging.error(f"Pretrain model not found: {e}")
            except Exception as e:
                logging.error(f"Error loading pretrain model: {e}")

        # Resume from a checkpoint
        if self._resume_ckpt_path is not None:
            try:
                loading_status = self._policy.deserialize(torch.load(self._resume_ckpt_path))
                logging.info(f"Resume policy from: {self._resume_ckpt_path}, Status: {loading_status}")
            except FileNotFoundError as e:
                logging.error(f"Checkpoint not found: {e}")
            except Exception as e:
                logging.error(f"Error loading checkpoint: {e}")

        self._policy.cuda()

        self._optimizer = self._make_optimizer()
        train_dataloader = self._repeater()  # infinite repeater

        for step in tqdm(range(self._num_steps + 1)):
            # Validate the model every `validate_every` steps; the training
            # step below runs on every iteration (the previous `continue`
            # skipped training on all non-validation steps).
            if step % self._validate_every == 0:
                logging.info("validating")
                with torch.inference_mode():
                    self._policy.eval()
                    validation_dicts = []
                    for batch_idx, data in enumerate(self._val_dataloader):
                        forward_dict = self._forward_pass(data)  # forward_dict = {"loss": loss, "kl": kl, "mse": mse}
                        validation_dicts.append(forward_dict)
                        if batch_idx > 50:  # cap the number of validation batches  TODO: confirm how this cap relates to the batch size
                            break

                    validation_summary = compute_dict_mean(validation_dicts)
                    epoch_val_loss = validation_summary["loss"]
                    if epoch_val_loss < min_val_loss:
                        min_val_loss = epoch_val_loss
                        best_ckpt_info = (
                            step,
                            min_val_loss,
                            deepcopy(self._policy.serialize()),
                        )

                # Log validation results to wandb
                # for k in list(validation_summary.keys()):
                #     validation_summary[f"val_{k}"] = validation_summary.pop(k)

                # wandb.log(validation_summary, step=step)
                logging.info(f"Val loss: {epoch_val_loss:.5f}")
                summary_string = " ".join(f"{k}: {v.item():.3f}" for k, v in validation_summary.items())
                logging.info(summary_string)

            # Evaluate the model
            # if (step > 0) and (step % self._eval_every == 0):
            #     ckpt_name = f"policy_step_{step}_seed_{self._seed}.ckpt"
            #     ckpt_path = os.path.join(self._ckpt_dir, ckpt_name)
            #     torch.save(self._policy.serialize(), ckpt_path)
            #     success, _ = self._eval.evaluate(ckpt_name)
            #     wandb.log({"success": success}, step=step)

            # Train the model
            self._policy.train()
            self._optimizer.zero_grad()
            data = next(train_dataloader)
            forward_dict = self._forward_pass(data)
            loss = forward_dict["loss"]
            loss.backward()
            self._optimizer.step()
            # wandb.log(forward_dict, step=step)

            # Save a checkpoint
            if step % self._save_every == 0:
                ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{step}_seed_{self._seed}.ckpt")
                torch.save(self._policy.serialize(), ckpt_path)

        # Save the final model
        ckpt_path = os.path.join(self._ckpt_dir, "policy_last.ckpt")
        torch.save(self._policy.serialize(), ckpt_path)

        best_step, min_val_loss, best_state_dict = best_ckpt_info
        ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{best_step}_seed_{self._seed}.ckpt")
        torch.save(best_state_dict, ckpt_path)
        logging.info(f"Training finished:\nSeed {self._seed}, val loss {min_val_loss:.6f} at step {best_step}")

        return best_ckpt_info


if __name__ == "__main__":
    with open("/home/rm/aloha/shadow_rm_act/config/config.yaml") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    trainer = RmActTrainer(config)
    trainer.train()
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes give inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


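# Illustrative check (not in the original file): two unit squares sharing an
# edge have zero overlap, and their enclosing box has no empty area, so both
# IoU and GIoU are 0; disjoint boxes drift toward -1 as the gap grows.
if __name__ == "__main__":
    a = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
    b = torch.tensor([[1.0, 0.0, 2.0, 1.0]])
    print(generalized_box_iou(a, b))  # tensor([[0.]])

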
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensor, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
@@ -0,0 +1,468 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List

import torch
import torch.distributed as dist
from torch import Tensor

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


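# Illustrative sketch (not in the original file): with a window of 3, the
# median is taken over the last 3 updates while global_avg covers all 4.
if __name__ == "__main__":
    meter = SmoothedValue(window_size=3, fmt="{median:.2f} ({global_avg:.2f})")
    for v in [1.0, 2.0, 3.0, 4.0]:
        meter.update(v)
    print(str(meter))  # "3.00 (2.50)"

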
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


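# Illustrative sketch (not in the original file): batching two differently
# sized images pads to the per-axis maximum and returns a padding mask.
if __name__ == "__main__":
    imgs = [torch.rand(3, 4, 6), torch.rand(3, 5, 3)]
    padded, mask = nested_tensor_from_tensor_list(imgs).decompose()
    print(padded.shape, mask.shape)  # torch.Size([2, 3, 5, 6]) torch.Size([2, 5, 6])

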
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    function can go away.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
@@ -0,0 +1,107 @@
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path, PurePath


def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.

    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
    # convert single Path to list to avoid 'not iterable' error

    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(
                f"{func_name} - invalid argument for logs parameter.\n"
                f"Expect list[Path] or single Path obj, received {type(logs)}"
            )

    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
    for i, dir in enumerate(logs):
        if not isinstance(dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
        if not dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
        # verify log_name exists
        fn = Path(dir / log_name)
        if not fn.exists():
            print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return

    # load log file(s) and plot
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))

    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )
    for ax, field in zip(axs, fields):
        ax.legend([Path(p).name for p in logs])
        ax.set_title(field)


def plot_precision_recall(files, naming_scheme='iter'):
    if naming_scheme == 'exp_id':
        # name becomes exp_id
        names = [f.parts[-3] for f in files]
    elif naming_scheme == 'iter':
        names = [f.stem for f in files]
    else:
        raise ValueError(f'not supported {naming_scheme}')
    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
        data = torch.load(f)
        # precision is n_iou, n_points, n_cat, n_area, max_det
        precision = data['precision']
        recall = data['params'].recThrs
        scores = data['scores']
        # take precision for all classes, all areas and 100 detections
        precision = precision[0, :, :, 0, -1].mean(1)
        scores = scores[0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
              f'score={scores.mean():0.3f}, ' +
              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
              )
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    axs[0].set_title('Precision / Recall')
    axs[0].legend(names)
    axs[1].set_title('Scores / Recall')
    axs[1].legend(names)
    return fig, axs
@@ -0,0 +1,499 @@
import numpy as np
import torch
import os
import h5py
import pickle
import fnmatch
import cv2
from time import time
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms


def flatten_list(l):
    return [item for sublist in l for item in sublist]


class EpisodicDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        dataset_path_list,
        camera_names,
        norm_stats,
        episode_ids,
        episode_len,
        chunk_size,
        policy_class,
    ):
        super().__init__()
        self.episode_ids = episode_ids
        self.dataset_path_list = dataset_path_list
        self.camera_names = camera_names
        self.norm_stats = norm_stats
        self.episode_len = episode_len
        self.chunk_size = chunk_size
        self.cumulative_len = np.cumsum(self.episode_len)
        self.max_episode_len = max(episode_len)
        self.policy_class = policy_class
        if self.policy_class == "Diffusion":
            self.augment_images = True
        else:
            self.augment_images = False
        self.transformations = None
        self.__getitem__(0)  # initialize self.is_sim and self.transformations
        self.is_sim = False

    # def __len__(self):
    #     return sum(self.episode_len)

    def _locate_transition(self, index):
        assert index < self.cumulative_len[-1]
        episode_index = np.argmax(
            self.cumulative_len > index
        )  # argmax returns first True index
        start_ts = index - (
            self.cumulative_len[episode_index] - self.episode_len[episode_index]
        )
        episode_id = self.episode_ids[episode_index]
        return episode_id, start_ts

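    # Illustrative note (not in the original file): with episode_len = [10, 5],
    # cumulative_len is [10, 15], so the flat index 12 satisfies
    # cumulative_len > 12 first at episode_index 1, and
    # start_ts = 12 - (15 - 5) = 2, i.e. episode 1, timestep 2.
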
    def __getitem__(self, index):
        episode_id, start_ts = self._locate_transition(index)
        dataset_path = self.dataset_path_list[episode_id]
        try:
            # print(dataset_path)
            with h5py.File(dataset_path, "r") as root:
                try:  # some legacy data does not have this attribute
                    is_sim = root.attrs["sim"]
                except KeyError:
                    is_sim = False
                compressed = root.attrs.get("compress", False)
                if "/base_action" in root:
                    base_action = root["/base_action"][()]
                    base_action = preprocess_base_action(base_action)
                    action = np.concatenate([root["/action"][()], base_action], axis=-1)
                else:
                    # TODO
                    action = root["/action"][()]
                    # dummy_base_action = np.zeros([action.shape[0], 2])
                    # action = np.concatenate([action, dummy_base_action], axis=-1)
                original_action_shape = action.shape
                episode_len = original_action_shape[0]
                # get observation at start_ts only
                qpos = root["/observations/qpos"][start_ts]
                qvel = root["/observations/qvel"][start_ts]
                image_dict = dict()
                for cam_name in self.camera_names:
                    image_dict[cam_name] = root[f"/observations/images/{cam_name}"][
                        start_ts
                    ]

                if compressed:
                    for cam_name in image_dict.keys():
                        decompressed_image = cv2.imdecode(image_dict[cam_name], 1)
                        image_dict[cam_name] = np.array(decompressed_image)

                # get all actions after and including start_ts
                if is_sim:
                    action = action[start_ts:]
                    action_len = episode_len - start_ts
                else:
                    action = action[
                        max(0, start_ts - 1) :
                    ]  # hack, to make timesteps more aligned
                    action_len = episode_len - max(
                        0, start_ts - 1
                    )  # hack, to make timesteps more aligned

            # self.is_sim = is_sim
            padded_action = np.zeros(
                (self.max_episode_len, original_action_shape[1]), dtype=np.float32
            )
            padded_action[:action_len] = action
            is_pad = np.zeros(self.max_episode_len)
            is_pad[action_len:] = 1

            padded_action = padded_action[: self.chunk_size]
            is_pad = is_pad[: self.chunk_size]

            # new axis for different cameras
            all_cam_images = []
            for cam_name in self.camera_names:
                all_cam_images.append(image_dict[cam_name])
            all_cam_images = np.stack(all_cam_images, axis=0)

            # construct observations
            image_data = torch.from_numpy(all_cam_images)
            qpos_data = torch.from_numpy(qpos).float()
            action_data = torch.from_numpy(padded_action).float()
            is_pad = torch.from_numpy(is_pad).bool()

            # move channels first: (k, h, w, c) -> (k, c, h, w)
            image_data = torch.einsum("k h w c -> k c h w", image_data)

            # augmentation
            if self.transformations is None:
                print("Initializing transformations")
                original_size = image_data.shape[2:]
                ratio = 0.95
                self.transformations = [
                    transforms.RandomCrop(
                        size=[
                            int(original_size[0] * ratio),
                            int(original_size[1] * ratio),
                        ]
                    ),
                    transforms.Resize(original_size, antialias=True),
                    transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
                    transforms.ColorJitter(
                        brightness=0.3, contrast=0.4, saturation=0.5
                    ),  # , hue=0.08)
                ]

            if self.augment_images:
                for transform in self.transformations:
                    image_data = transform(image_data)

            # normalize image and change dtype to float
            image_data = image_data / 255.0

            if self.policy_class == "Diffusion":
                # normalize to [-1, 1]
                action_data = (
                    (action_data - self.norm_stats["action_min"])
                    / (self.norm_stats["action_max"] - self.norm_stats["action_min"])
                ) * 2 - 1
            else:
                # normalize to mean 0 std 1
                action_data = (
                    action_data - self.norm_stats["action_mean"]
                ) / self.norm_stats["action_std"]

            qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
                "qpos_std"
            ]

        except Exception as e:
            print(f"Error loading {dataset_path} in __getitem__")
            print(e)
            quit()

        # print(image_data.dtype, qpos_data.dtype, action_data.dtype, is_pad.dtype)
        return image_data, qpos_data, action_data, is_pad


def get_norm_stats(dataset_path_list):
    all_qpos_data = []
    all_action_data = []
    all_episode_len = []

    for dataset_path in dataset_path_list:
        try:
            with h5py.File(dataset_path, "r") as root:
                qpos = root["/observations/qpos"][()]
                qvel = root["/observations/qvel"][()]
                if "/base_action" in root:
                    base_action = root["/base_action"][()]
                    # apply the same preprocessing as __getitem__ so the
                    # statistics match the training data; without this
                    # assignment `action` was left unbound in this branch
                    base_action = preprocess_base_action(base_action)
                    action = np.concatenate([root["/action"][()], base_action], axis=-1)
                else:
                    # TODO
                    action = root["/action"][()]
                    # dummy_base_action = np.zeros([action.shape[0], 2])
                    # action = np.concatenate([action, dummy_base_action], axis=-1)
        except Exception as e:
            print(f"Error loading {dataset_path} in get_norm_stats")
            print(e)
            quit()
        all_qpos_data.append(torch.from_numpy(qpos))
        all_action_data.append(torch.from_numpy(action))
        all_episode_len.append(len(qpos))
    all_qpos_data = torch.cat(all_qpos_data, dim=0)
    all_action_data = torch.cat(all_action_data, dim=0)

    # normalize action data
    action_mean = all_action_data.mean(dim=[0]).float()
    action_std = all_action_data.std(dim=[0]).float()
    action_std = torch.clip(action_std, 1e-2, np.inf)  # clipping

    # normalize qpos data
    qpos_mean = all_qpos_data.mean(dim=[0]).float()
    qpos_std = all_qpos_data.std(dim=[0]).float()
    qpos_std = torch.clip(qpos_std, 1e-2, np.inf)  # clipping

    action_min = all_action_data.min(dim=0).values.float()
    action_max = all_action_data.max(dim=0).values.float()

    eps = 0.0001
    stats = {
        "action_mean": action_mean.numpy(),
        "action_std": action_std.numpy(),
        "action_min": action_min.numpy() - eps,
        "action_max": action_max.numpy() + eps,
        "qpos_mean": qpos_mean.numpy(),
        "qpos_std": qpos_std.numpy(),
        "example_qpos": qpos,
    }

    return stats, all_episode_len


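# Illustrative sketch (not in the original file): how the returned statistics
# are applied and inverted by the dataset and evaluation code. `stats` is the
# dict returned by get_norm_stats and `action` a raw action array; both names
# are placeholders here.
#
#   normalized = (action - stats["action_mean"]) / stats["action_std"]
#   restored = normalized * stats["action_std"] + stats["action_mean"]

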
def find_all_hdf5(dataset_dir, skip_mirrored_data):
    hdf5_files = []
    for root, dirs, files in os.walk(dataset_dir):
        for filename in fnmatch.filter(files, "*.hdf5"):
            if "features" in filename:
                continue
            if skip_mirrored_data and "mirror" in filename:
                continue
            hdf5_files.append(os.path.join(root, filename))
    print(f"Found {len(hdf5_files)} hdf5 files")
    return hdf5_files


def BatchSampler(batch_size, episode_len_l, sample_weights):
    sample_probs = (
        np.array(sample_weights) / np.sum(sample_weights)
        if sample_weights is not None
        else None
    )
    # print("BatchSampler", sample_weights)
    sum_dataset_len_l = np.cumsum(
        [0] + [np.sum(episode_len) for episode_len in episode_len_l]
    )
    while True:
        batch = []
        for _ in range(batch_size):
            episode_idx = np.random.choice(len(episode_len_l), p=sample_probs)
            step_idx = np.random.randint(
                sum_dataset_len_l[episode_idx], sum_dataset_len_l[episode_idx + 1]
            )
            batch.append(step_idx)
        yield batch


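# Illustrative check (not in the original file): the sampler yields lists of
# flat step indices into the concatenated datasets (here 15 + 8 = 23 steps).
if __name__ == "__main__":
    sampler = BatchSampler(4, [[10, 5], [8]], None)
    batch = next(sampler)
    assert len(batch) == 4 and all(0 <= i < 23 for i in batch)

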
def load_data(
    dataset_dir_l,
    name_filter,
    camera_names,
    batch_size_train,
    batch_size_val,
    chunk_size,
    skip_mirrored_data=False,
    load_pretrain=False,
    policy_class=None,
    stats_dir_l=None,
    sample_weights=None,
    train_ratio=0.99,
):
    if isinstance(dataset_dir_l, str):
        dataset_dir_l = [dataset_dir_l]
    dataset_path_list_list = [
        find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l
    ]
    num_episodes_0 = len(dataset_path_list_list[0])
    dataset_path_list = flatten_list(dataset_path_list_list)
    dataset_path_list = [n for n in dataset_path_list if name_filter(n)]
    num_episodes_l = [
        len(dataset_path_list) for dataset_path_list in dataset_path_list_list
    ]
    num_episodes_cumsum = np.cumsum(num_episodes_l)

    # Obtain the train/test split on dataset_dir_l[0]; episodes from the
    # remaining datasets are used for training only.
    shuffled_episode_ids_0 = np.random.permutation(num_episodes_0)
    train_episode_ids_0 = shuffled_episode_ids_0[: int(train_ratio * num_episodes_0)]
    val_episode_ids_0 = shuffled_episode_ids_0[int(train_ratio * num_episodes_0) :]
    train_episode_ids_l = [train_episode_ids_0] + [
        np.arange(num_episodes) + num_episodes_cumsum[idx]
        for idx, num_episodes in enumerate(num_episodes_l[1:])
    ]
    val_episode_ids_l = [val_episode_ids_0]
    train_episode_ids = np.concatenate(train_episode_ids_l)
    val_episode_ids = np.concatenate(val_episode_ids_l)
    print(
        f"\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n"
    )

    # Obtain normalization stats for qpos and action.
    # if load_pretrain:
    #     with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f:
    #         norm_stats = pickle.load(f)
    #     print('Loaded pretrain dataset stats')
    _, all_episode_len = get_norm_stats(dataset_path_list)
    train_episode_len_l = [
        [all_episode_len[i] for i in train_episode_ids]
        for train_episode_ids in train_episode_ids_l
    ]
    val_episode_len_l = [
        [all_episode_len[i] for i in val_episode_ids]
        for val_episode_ids in val_episode_ids_l
    ]

    train_episode_len = flatten_list(train_episode_len_l)
    val_episode_len = flatten_list(val_episode_len_l)
    if stats_dir_l is None:
        stats_dir_l = dataset_dir_l
    elif isinstance(stats_dir_l, str):
        stats_dir_l = [stats_dir_l]
    norm_stats, _ = get_norm_stats(
        flatten_list(
            [find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l]
        )
    )
    print(f"Norm stats from: {stats_dir_l}")

    batch_sampler_train = BatchSampler(
        batch_size_train, train_episode_len_l, sample_weights
    )
    batch_sampler_val = BatchSampler(batch_size_val, val_episode_len_l, None)

    # Construct datasets and dataloaders.
    train_dataset = EpisodicDataset(
        dataset_path_list,
        camera_names,
        norm_stats,
        train_episode_ids,
        train_episode_len,
        chunk_size,
        policy_class,
    )
    val_dataset = EpisodicDataset(
        dataset_path_list,
        camera_names,
        norm_stats,
        val_episode_ids,
        val_episode_len,
        chunk_size,
        policy_class,
    )
    train_num_workers = (
        (8 if os.getlogin() == "zfu" else 16) if train_dataset.augment_images else 2
    )
    val_num_workers = 8 if train_dataset.augment_images else 2
    print(
        f"Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}"
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler_train,
        pin_memory=True,
        num_workers=train_num_workers,
        prefetch_factor=2,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_sampler=batch_sampler_val,
        pin_memory=True,
        num_workers=val_num_workers,
        prefetch_factor=2,
    )

    return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim


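# Usage sketch (illustration only): a typical call to load_data. The paths,
# camera names, and batch sizes are hypothetical, and the unpacking of one
# batch assumes EpisodicDataset yields (image, qpos, action, is_pad) tuples as
# in the upstream ACT code.
def _demo_load_data():
    train_loader, val_loader, norm_stats, is_sim = load_data(
        dataset_dir_l="/tmp/aloha_data",
        name_filter=lambda name: True,
        camera_names=["cam_high", "cam_low", "cam_left", "cam_right"],
        batch_size_train=8,
        batch_size_val=8,
        chunk_size=100,
    )
    image_data, qpos_data, action_data, is_pad = next(iter(train_loader))
    return image_data.shape, qpos_data.shape

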
def calibrate_linear_vel(base_action, c=None):
    # Compensate linear velocity for its coupling with angular velocity.
    if c is None:
        c = 0.0  # 0.19
    v = base_action[..., 0]
    w = base_action[..., 1]
    base_action = base_action.copy()
    base_action[..., 0] = v - c * w
    return base_action


def smooth_base_action(base_action):
    # 5-step moving average over time, applied per action dimension.
    return np.stack(
        [
            np.convolve(base_action[:, i], np.ones(5) / 5, mode="same")
            for i in range(base_action.shape[1])
        ],
        axis=-1,
    ).astype(np.float32)


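# Worked example (illustration only): with mode="same" the 5-tap box filter
# keeps the sequence length but tapers at the boundaries, where the window
# extends past the sequence and the missing samples count as zero.
def _demo_smooth_base_action():
    base_action = np.array([[1.0, 0.0]] * 10, dtype=np.float32)  # constant input
    smoothed = smooth_base_action(base_action)
    # Interior samples stay 1.0; the first/last two become 0.6 and 0.8.
    return smoothed[:, 0]

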
def preprocess_base_action(base_action):
    # base_action = calibrate_linear_vel(base_action)
    base_action = smooth_base_action(base_action)

    return base_action


def postprocess_base_action(base_action):
    linear_vel, angular_vel = base_action
    linear_vel *= 1.0
    angular_vel *= 1.0
    # angular_vel = 0
    # if np.abs(linear_vel) < 0.05:
    #     linear_vel = 0
    return np.array([linear_vel, angular_vel])


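# Pipeline sketch (illustration only): presumably, recorded base actions are
# smoothed before training and per-step policy outputs pass through
# postprocess_base_action before being sent to the base. The data is made up.
def _demo_base_action_pipeline():
    recorded = np.random.uniform(-0.2, 0.2, size=(50, 2)).astype(np.float32)
    train_targets = preprocess_base_action(recorded)
    one_step = postprocess_base_action(train_targets[0])
    return one_step  # [linear_vel, angular_vel]

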
### env utils


def sample_box_pose():
    x_range = [0.0, 0.2]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    cube_quat = np.array([1, 0, 0, 0])
    return np.concatenate([cube_position, cube_quat])


def sample_insertion_pose():
    # Peg
    x_range = [0.1, 0.2]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    peg_quat = np.array([1, 0, 0, 0])
    peg_pose = np.concatenate([peg_position, peg_quat])

    # Socket
    x_range = [-0.2, -0.1]
    y_range = [0.4, 0.6]
    z_range = [0.05, 0.05]

    ranges = np.vstack([x_range, y_range, z_range])
    socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])

    socket_quat = np.array([1, 0, 0, 0])
    socket_pose = np.concatenate([socket_position, socket_quat])

    return peg_pose, socket_pose


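# Usage sketch (illustration only): each pose is a 7-vector
# [x, y, z, qw, qx, qy, qz] with an identity quaternion; seeding makes the
# sampled scene reproducible.
def _demo_sample_poses():
    np.random.seed(0)
    cube_pose = sample_box_pose()
    peg_pose, socket_pose = sample_insertion_pose()
    return cube_pose, peg_pose, socket_pose

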
### helper functions


def compute_dict_mean(epoch_dicts):
    # Average each key across a list of dicts that share the same keys.
    result = {k: None for k in epoch_dicts[0]}
    num_items = len(epoch_dicts)
    for k in result:
        value_sum = 0
        for epoch_dict in epoch_dicts:
            value_sum += epoch_dict[k]
        result[k] = value_sum / num_items
    return result


def detach_dict(d):
    new_d = dict()
    for k, v in d.items():
        new_d[k] = v.detach()
    return new_d


def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)


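# Usage sketch (illustration only): averaging per-batch loss dicts into an
# epoch summary, as in a typical training loop. Values are made up.
def _demo_compute_dict_mean():
    epoch_dicts = [{"loss": 1.0, "kl": 0.2}, {"loss": 3.0, "kl": 0.4}]
    return compute_dict_mean(epoch_dicts)  # {"loss": 2.0, "kl": 0.3}

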
163
realman_src/realman_aloha/shadow_rm_act/test/test_camera.py
Normal file
@@ -0,0 +1,163 @@
from shadow_camera.realsense import RealSenseCamera
from shadow_rm_robot.realman_arm import RmArm
import yaml
import time
import multiprocessing
import numpy as np
import collections
import logging
import dm_env
import tracemalloc


class DeviceAloha:
    def __init__(self, aloha_config):
        """
        Initialize the device.

        Args:
            aloha_config (dict): device configuration
        """
        config_left_arm = aloha_config["rm_left_arm"]
        config_right_arm = aloha_config["rm_right_arm"]
        config_head_camera = aloha_config["head_camera"]
        config_bottom_camera = aloha_config["bottom_camera"]
        config_left_camera = aloha_config["left_camera"]
        config_right_camera = aloha_config["right_camera"]
        self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
        self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
        self.arm_left = RmArm(config_left_arm)
        self.arm_right = RmArm(config_right_arm)
        # Each camera receives its matching config (the original assignments
        # mixed up the four configs).
        self.camera_top = RealSenseCamera(config_head_camera, False)
        self.camera_bottom = RealSenseCamera(config_bottom_camera, False)
        self.camera_left = RealSenseCamera(config_left_camera, False)
        self.camera_right = RealSenseCamera(config_right_camera, False)
        self.camera_left.start_camera()
        self.camera_right.start_camera()
        self.camera_bottom.start_camera()
        self.camera_top.start_camera()

    def close(self):
        """
        Close all cameras.
        """
        self.camera_left.close()
        self.camera_right.close()
        self.camera_bottom.close()
        self.camera_top.close()

    def get_qpos(self):
        """
        Get joint positions.

        Returns:
            np.array: joint angles of both arms, left then right
        """
        left_slave_arm_angle = self.arm_left.get_joint_angle()
        left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
        right_slave_arm_angle = self.arm_right.get_joint_angle()
        right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
        return np.concatenate([left_joint_angles_array, right_joint_angles_array])

    def get_qvel(self):
        """
        Get joint velocities.

        Returns:
            np.array: joint velocities of both arms, left then right
        """
        left_slave_arm_velocity = self.arm_left.get_joint_velocity()
        left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
        right_slave_arm_velocity = self.arm_right.get_joint_velocity()
        right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
        return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])

    def get_effort(self):
        """
        Get joint efforts.

        Returns:
            np.array: joint efforts of both arms, left then right
        """
        left_slave_arm_effort = self.arm_left.get_joint_effort()
        left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
        right_slave_arm_effort = self.arm_right.get_joint_effort()
        right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
        return np.concatenate([left_joint_effort_array, right_joint_effort_array])

    def get_images(self):
        """
        Get one frame from each camera.

        Returns:
            dict: camera name -> image
        """
        top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
        bottom_image, _, _, _ = self.camera_bottom.read_frame(True, False, False, False)
        left_image, _, _, _ = self.camera_left.read_frame(True, False, False, False)
        right_image, _, _, _ = self.camera_right.read_frame(True, False, False, False)
        return {
            "cam_high": top_image,
            "cam_low": bottom_image,
            "cam_left": left_image,
            "cam_right": right_image,
        }

    def get_observation(self):
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def reset(self):
        logging.info("Resetting the environment")
        _ = self.arm_left.set_joint_position(self.init_left_arm_angle[0:6])
        _ = self.arm_right.set_joint_position(self.init_right_arm_angle[0:6])
        self.arm_left.set_gripper_position(0)
        self.arm_right.set_gripper_position(0)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )

    def step(self, target_angle):
        # target_angle layout: [0:6] left joints, [6] left gripper,
        # [7:13] right joints, [13] right gripper.
        self.arm_left.set_joint_canfd_position(target_angle[0:6])
        self.arm_right.set_joint_canfd_position(target_angle[7:13])
        self.arm_left.set_gripper_position(target_angle[6])
        self.arm_right.set_gripper_position(target_angle[13])
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID,
            reward=0,
            discount=None,
            observation=self.get_observation(),
        )


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
        config = yaml.safe_load(f)
    aloha_config = config["robot_env"]
    device = DeviceAloha(aloha_config)
    device.reset()
    image_list = []
    target_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
    while True:
        tracemalloc.start()  # start memory tracing

        # Increment every joint by 0.1 rad per step; skip gripper indices 6 and 13.
        target_angle = np.array([angle + 0.1 if i not in [6, 13] else angle for i, angle in enumerate(target_angle)])
        time_step = time.time()
        timestep = device.step(target_angle)
        logging.info(f"Time: {time.time() - time_step}")
        image_list.append(timestep.observation["images"])
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')
        # del timestep
        print("[ Top 10 ]")
        for stat in top_stats[:10]:
            print(stat)
        # logging.info(f"Images: {obs}")

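# Usage sketch (illustration only): the shape of one observation from
# DeviceAloha, assuming each arm reports 7 values (6 joints + gripper) and
# RealSenseCamera.read_frame returns RGB frames.
def _demo_observation_shapes(device):
    obs = device.get_observation()
    qpos = obs["qpos"]      # (14,): left arm then right arm, under the assumption above
    images = obs["images"]  # keys: cam_high, cam_low, cam_left, cam_right
    return qpos.shape, list(images.keys())
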
32
realman_src/realman_aloha/shadow_rm_act/test/test_h5.py
Normal file
@@ -0,0 +1,32 @@
import os
# import time
import yaml
import torch
import pickle
import dm_env
import logging
import collections
import numpy as np
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision import transforms
from shadow_rm_robot.realman_arm import RmArm
from shadow_camera.realsense import RealSenseCamera
from shadow_act.models.latent_model import Latent_Model_Transformer
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.utils.utils import (
    load_data,
    sample_box_pose,
    sample_insertion_pose,
    compute_dict_mean,
    set_seed,
    detach_dict,
)

logging.basicConfig(level=logging.DEBUG)


print("test_h5: imports loaded")  # placeholder; the original printed a junk debug string

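# Sketch (illustration only) of what this stub presumably intends to test:
# reading one recorded episode back from disk. The path is hypothetical; the
# dataset layout matches load_hdf5 in visualize_episodes.py.
def _demo_read_episode(dataset_path="/tmp/aloha_data/episode_0.hdf5"):
    import h5py
    with h5py.File(dataset_path, "r") as root:
        qpos = root["/observations/qpos"][()]
        action = root["/action"][()]
    return qpos.shape, action.shape
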
147
realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py
Normal file
@@ -0,0 +1,147 @@
import os
import numpy as np
import cv2
import h5py
import argparse

import matplotlib.pyplot as plt
from constants import DT

import IPython
e = IPython.embed

JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
STATE_NAMES = JOINT_NAMES + ["gripper"]


def load_hdf5(dataset_dir, dataset_name):
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        exit()

    with h5py.File(dataset_path, 'r') as root:
        is_sim = root.attrs['sim']
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        action = root['/action'][()]
        image_dict = dict()
        for cam_name in root['/observations/images/'].keys():
            image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]

    return qpos, qvel, action, image_dict


def main(args):
    dataset_dir = args['dataset_dir']
    episode_idx = args['episode_idx']
    dataset_name = f'episode_{episode_idx}'

    qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
    save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
    visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
    # visualize_timestamp(t_list, dataset_path)  # TODO: add timestamps back


def save_videos(video, dt, video_path=None):
    if isinstance(video, list):
        # video is a list of per-timestep {cam_name: image} dicts.
        cam_names = list(video[0].keys())
        h, w, _ = video[0][cam_names[0]].shape
        w = w * len(cam_names)
        fps = int(1 / dt)
        out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        for ts, image_dict in enumerate(video):
            images = []
            for cam_name in cam_names:
                image = image_dict[cam_name]
                image = image[:, :, [2, 1, 0]]  # swap B and R channel
                images.append(image)
            images = np.concatenate(images, axis=1)
            out.write(images)
        out.release()
        print(f'Saved video to: {video_path}')
    elif isinstance(video, dict):
        # video is {cam_name: (n_frames, h, w, c) array}.
        cam_names = list(video.keys())
        all_cam_videos = []
        for cam_name in cam_names:
            all_cam_videos.append(video[cam_name])
        all_cam_videos = np.concatenate(all_cam_videos, axis=2)  # width dimension

        n_frames, h, w, _ = all_cam_videos.shape
        fps = int(1 / dt)
        out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        for t in range(n_frames):
            image = all_cam_videos[t]
            image = image[:, :, [2, 1, 0]]  # swap B and R channel
            out.write(image)
        out.release()
        print(f'Saved video to: {video_path}')


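# Usage sketch (illustration only): writing a short synthetic two-camera clip
# with the dict input form; the frame sizes and dt value are made up.
def _demo_save_videos():
    video = {
        "cam_high": np.zeros((30, 48, 64, 3), dtype=np.uint8),
        "cam_low": np.zeros((30, 48, 64, 3), dtype=np.uint8),
    }
    save_videos(video, dt=0.02, video_path="/tmp/demo_video.mp4")

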
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
    if label_overwrite:
        label1, label2 = label_overwrite
    else:
        label1, label2 = 'State', 'Command'

    qpos = np.array(qpos_list)  # ts, dim
    command = np.array(command_list)
    num_ts, num_dim = qpos.shape
    h, w = 2, num_dim
    num_figs = num_dim
    fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))

    # plot joint state
    all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
    for dim_idx in range(num_dim):
        ax = axs[dim_idx]
        ax.plot(qpos[:, dim_idx], label=label1)
        ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
        ax.legend()

    # plot arm command
    for dim_idx in range(num_dim):
        ax = axs[dim_idx]
        ax.plot(command[:, dim_idx], label=label2)
        ax.legend()

    if ylim:
        for dim_idx in range(num_dim):
            ax = axs[dim_idx]
            ax.set_ylim(ylim)

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved qpos plot to: {plot_path}')
    plt.close()


def visualize_timestamp(t_list, dataset_path):
    plot_path = dataset_path.replace('.pkl', '_timestamp.png')
    h, w = 4, 10
    fig, axs = plt.subplots(2, 1, figsize=(w, h * 2))
    # convert (secs, nsecs) pairs to float seconds
    t_float = []
    for secs, nsecs in t_list:
        t_float.append(secs + nsecs * 1e-9)
    t_float = np.array(t_float)

    ax = axs[0]
    ax.plot(np.arange(len(t_float)), t_float)
    ax.set_title('Camera frame timestamps')
    ax.set_xlabel('timestep')
    ax.set_ylabel('time (sec)')

    ax = axs[1]
    ax.plot(np.arange(len(t_float) - 1), t_float[1:] - t_float[:-1])  # dt between consecutive frames
    ax.set_title('dt')
    ax.set_xlabel('timestep')
    ax.set_ylabel('time (sec)')

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved timestamp plot to: {plot_path}')
    plt.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
    parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
    main(vars(parser.parse_args()))