From a2c181992a131c92e1930f3ebc3014112fe03625 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 3 Dec 2024 00:51:55 +0100 Subject: [PATCH] Refactor OpenX (#505) --- .../push_dataset_to_hub/openx/configs.yaml | 639 ------------- .../push_dataset_to_hub/openx/data_utils.py | 106 --- .../push_dataset_to_hub/openx/droid_utils.py | 200 ---- .../push_dataset_to_hub/openx/transforms.py | 859 ------------------ .../push_dataset_to_hub/openx_rlds_format.py | 143 +-- lerobot/scripts/push_dataset_to_hub.py | 30 +- 6 files changed, 58 insertions(+), 1919 deletions(-) delete mode 100644 lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml delete mode 100644 lerobot/common/datasets/push_dataset_to_hub/openx/data_utils.py delete mode 100644 lerobot/common/datasets/push_dataset_to_hub/openx/droid_utils.py delete mode 100644 lerobot/common/datasets/push_dataset_to_hub/openx/transforms.py diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml b/lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml deleted file mode 100644 index f706270a..00000000 --- a/lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml +++ /dev/null @@ -1,639 +0,0 @@ -OPENX_DATASET_CONFIGS: - fractal20220817_data: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - base_pose_tool_reached - - gripper_closed - fps: 3 - - kuka: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - clip_function_input/base_pose_tool_reached - - gripper_closed - fps: 10 - - bridge_openx: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - EEF_state - - gripper_state - fps: 5 - - taco_play: - image_obs_keys: - - rgb_static - - rgb_gripper - depth_obs_keys: - - depth_static - - depth_gripper - state_obs_keys: - - state_eef - - state_gripper - fps: 15 - - jaco_play: - image_obs_keys: - - image - - image_wrist - depth_obs_keys: - - null - state_obs_keys: - - state_eef - - state_gripper - fps: 10 - - berkeley_cable_routing: - image_obs_keys: - - image - - top_image - - wrist45_image - - wrist225_image - depth_obs_keys: - - null - state_obs_keys: - - robot_state - fps: 10 - - roboturk: - image_obs_keys: - - front_rgb - depth_obs_keys: - - null - state_obs_keys: - - null - fps: 10 - - nyu_door_opening_surprising_effectiveness: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - null - fps: 3 - - viola: - image_obs_keys: - - agentview_rgb - - eye_in_hand_rgb - depth_obs_keys: - - null - state_obs_keys: - - joint_states - - gripper_states - fps: 20 - - berkeley_autolab_ur5: - image_obs_keys: - - image - - hand_image - depth_obs_keys: - - image_with_depth - state_obs_keys: - - state - fps: 5 - - toto: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 30 - - language_table: - image_obs_keys: - - rgb - depth_obs_keys: - - null - state_obs_keys: - - effector_translation - fps: 10 - - columbia_cairlab_pusht_real: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - robot_state - fps: 10 - - stanford_kuka_multimodal_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - depth_image - state_obs_keys: - - ee_position - - ee_orientation - fps: 20 - - nyu_rot_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 3 - - io_ai_tech: - image_obs_keys: - - image - - image_fisheye - - image_left_side - - image_right_side - 
depth_obs_keys: - - null - state_obs_keys: - - state - fps: 3 - - stanford_hydra_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 10 - - austin_buds_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 20 - - nyu_franka_play_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - - image_additional_view - depth_obs_keys: - - depth - - depth_additional_view - state_obs_keys: - - eef_state - fps: 3 - - maniskill_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - depth - - wrist_depth - state_obs_keys: - - tcp_pose - - gripper_state - fps: 20 - - furniture_bench_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 10 - - cmu_franka_exploration_dataset_converted_externally_to_rlds: - image_obs_keys: - - highres_image - depth_obs_keys: - - null - state_obs_keys: - - null - fps: 10 - - ucsd_kitchen_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - joint_state - fps: 2 - - ucsd_pick_and_place_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 3 - - spoc: - image_obs_keys: - - image - - image_manipulation - depth_obs_keys: - - null - state_obs_keys: - - null - fps: 3 - - austin_sailor_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 20 - - austin_sirius_dataset_converted_externally_to_rlds: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 20 - - bc_z: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - present/xyz - - present/axis_angle - - present/sensed_close - fps: 10 - - utokyo_pr2_opening_fridge_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 10 - - utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 10 - - utokyo_xarm_pick_and_place_converted_externally_to_rlds: - image_obs_keys: - - image - - image2 - - hand_image - depth_obs_keys: - - null - state_obs_keys: - - end_effector_pose - fps: 10 - - utokyo_xarm_bimanual_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - pose_r - fps: 10 - - robo_net: - image_obs_keys: - - image - - image1 - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 1 - - robo_set: - image_obs_keys: - - image_left - - image_right - - image_wrist - depth_obs_keys: - - null - state_obs_keys: - - state - - state_velocity - fps: 5 - - berkeley_mvp_converted_externally_to_rlds: - image_obs_keys: - - hand_image - depth_obs_keys: - - null - state_obs_keys: - - gripper - - pose - - joint_pos - fps: 5 - - berkeley_rpt_converted_externally_to_rlds: - image_obs_keys: - - hand_image - depth_obs_keys: - - null - state_obs_keys: - - joint_pos - - gripper - fps: 30 - - kaist_nonprehensile_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - 
state - fps: 10 - - stanford_mask_vit_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - - tokyo_u_lsmo_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 10 - - dlr_sara_pour_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 10 - - dlr_sara_grid_clamp_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 10 - - dlr_edan_shared_control_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 5 - - asu_table_top_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 12.5 - - stanford_robocook_converted_externally_to_rlds: - image_obs_keys: - - image_1 - - image_2 - depth_obs_keys: - - depth_1 - - depth_2 - state_obs_keys: - - eef_state - - gripper_state - fps: 5 - - imperialcollege_sawyer_wrist_cam: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 10 - - iamlab_cmu_pickup_insert_converted_externally_to_rlds: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - joint_state - - gripper_state - fps: 20 - - uiuc_d3field: - image_obs_keys: - - image_1 - - image_2 - depth_obs_keys: - - depth_1 - - depth_2 - state_obs_keys: - - null - fps: 1 - - utaustin_mutex: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 20 - - berkeley_fanuc_manipulation: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - joint_state - - gripper_state - fps: 10 - - cmu_playing_with_food: - image_obs_keys: - - image - - finger_vision_1 - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 10 - - cmu_play_fusion: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 5 - - cmu_stretch: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - eef_state - - gripper_state - fps: 10 - - berkeley_gnm_recon: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - - position - - yaw - fps: 3 - - berkeley_gnm_cory_hall: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - - position - - yaw - fps: 5 - - berkeley_gnm_sac_son: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - state - - position - - yaw - fps: 10 - - droid: - image_obs_keys: - - exterior_image_1_left - - exterior_image_2_left - - wrist_image_left - depth_obs_keys: - - null - state_obs_keys: - - proprio - fps: 15 - - droid_100: - image_obs_keys: - - exterior_image_1_left - - exterior_image_2_left - - wrist_image_left - depth_obs_keys: - - null - state_obs_keys: - - proprio - fps: 15 - - fmb: - image_obs_keys: - - image_side_1 - - image_side_2 - - image_wrist_1 - - image_wrist_2 - depth_obs_keys: - - image_side_1_depth - - image_side_2_depth - - image_wrist_1_depth - - image_wrist_2_depth - state_obs_keys: - - proprio - fps: 10 - - dobbe: - image_obs_keys: - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - proprio - fps: 3.75 - - usc_cloth_sim_converted_externally_to_rlds: - image_obs_keys: - - image - depth_obs_keys: - - null - state_obs_keys: - - null - fps: 10 - 
- plex_robosuite: - image_obs_keys: - - image - - wrist_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 20 - - conq_hose_manipulation: - image_obs_keys: - - frontleft_fisheye_image - - frontright_fisheye_image - - hand_color_image - depth_obs_keys: - - null - state_obs_keys: - - state - fps: 30 diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx/data_utils.py b/lerobot/common/datasets/push_dataset_to_hub/openx/data_utils.py deleted file mode 100644 index 1582c67c..00000000 --- a/lerobot/common/datasets/push_dataset_to_hub/openx/data_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the Licens e. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -NOTE(YL): Adapted from: - Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py - -data_utils.py - -Additional utils for data processing. -""" - -from typing import Any, Dict, List - -import tensorflow as tf - - -def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor: - """ - Converts gripper actions from continuous to binary values (0 and 1). - - We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it - transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate - values based on the state that is reached _after_ those intermediate values. - - In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that - chunk of intermediate values as the last action in the trajectory. - - The `scan_fn` implements the following logic: - new_actions = np.empty_like(actions) - carry = actions[-1] - for i in reversed(range(actions.shape[0])): - if in_between_mask[i]: - carry = carry - else: - carry = float(open_mask[i]) - new_actions[i] = carry - """ - open_mask, closed_mask = actions > 0.95, actions < 0.05 - in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask)) - is_open_float = tf.cast(open_mask, tf.float32) - - def scan_fn(carry, i): - return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i]) - - return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True) - - -def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor: - return 1 - actions - - -def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor: - """ - Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open). - - Assumes that the first relative gripper is not redundant (i.e. close when already closed)! 
- """ - # Note =>> -1 for closing, 1 for opening, 0 for no change - opening_mask, closing_mask = actions < -0.1, actions > 0.1 - thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0)) - - def scan_fn(carry, i): - return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i]) - - # If no relative grasp, assumes open for whole trajectory - start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)] - start = tf.cond(start == 0, lambda: 1, lambda: start) - - # Note =>> -1 for closed, 1 for open - new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start) - new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5 - - return new_actions - - -# === Bridge-V2 =>> Dataset-Specific Transform === -def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]: - """Relabels actions to use reached proprioceptive state; discards last timestep (no-action).""" - movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6] - traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) - traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1) - - return traj_truncated - - -# === RLDS Dataset Initialization Utilities === -def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None: - print("\n######################################################################################") - print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #") - for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False): - pad = 80 - len(dataset_kwargs["name"]) - print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #") - print("######################################################################################\n") diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx/droid_utils.py b/lerobot/common/datasets/push_dataset_to_hub/openx/droid_utils.py deleted file mode 100644 index 22ac4d9e..00000000 --- a/lerobot/common/datasets/push_dataset_to_hub/openx/droid_utils.py +++ /dev/null @@ -1,200 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -NOTE(YL): Adapted from: - OpenVLA: https://github.com/openvla/openvla - -Episode transforms for DROID dataset. -""" - -from typing import Any, Dict - -import tensorflow as tf -import tensorflow_graphics.geometry.transformation as tfg - - -def rmat_to_euler(rot_mat): - return tfg.euler.from_rotation_matrix(rot_mat) - - -def euler_to_rmat(euler): - return tfg.rotation_matrix_3d.from_euler(euler) - - -def invert_rmat(rot_mat): - return tfg.rotation_matrix_3d.inverse(rot_mat) - - -def rotmat_to_rot6d(mat): - """ - Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix). 
- Args: - mat: rotation matrix - - Returns: 6d vector (first two rows of rotation matrix) - - """ - r6 = mat[..., :2, :] - r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :] - r6_flat = tf.concat([r6_0, r6_1], axis=-1) - return r6_flat - - -def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame): - """ - Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame. - Args: - velocity: 6d velocity action (3 x translation, 3 x rotation) - wrist_in_robot_frame: 6d pose of the end-effector in robot base frame - - Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) - - """ - r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) - r_frame_inv = invert_rmat(r_frame) - - # world to wrist: dT_pi = R^-1 dT_rbt - vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0] - - # world to wrist: dR_pi = R^-1 dR_rbt R - dr_ = euler_to_rmat(velocity[:, 3:6]) - dr_ = r_frame_inv @ (dr_ @ r_frame) - dr_r6 = rotmat_to_rot6d(dr_) - return tf.concat([vel_t, dr_r6], axis=-1) - - -def rand_swap_exterior_images(img1, img2): - """ - Randomly swaps the two exterior images (for training with single exterior input). - """ - return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) - - -def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - DROID dataset transformation for actions expressed in *base* frame of the robot. - """ - dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] - dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] - - trajectory["action"] = tf.concat( - ( - dt, - dr_, - 1 - trajectory["action_dict"]["gripper_position"], - ), - axis=-1, - ) - trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( - rand_swap_exterior_images( - trajectory["observation"]["exterior_image_1_left"], - trajectory["observation"]["exterior_image_2_left"], - ) - ) - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["cartesian_position"], - trajectory["observation"]["gripper_position"], - ), - axis=-1, - ) - return trajectory - - -def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - DROID dataset transformation for actions expressed in *wrist* frame of the robot. - """ - wrist_act = velocity_act_to_wrist_frame( - trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] - ) - trajectory["action"] = tf.concat( - ( - wrist_act, - trajectory["action_dict"]["gripper_position"], - ), - axis=-1, - ) - trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( - rand_swap_exterior_images( - trajectory["observation"]["exterior_image_1_left"], - trajectory["observation"]["exterior_image_2_left"], - ) - ) - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["cartesian_position"], - trajectory["observation"]["gripper_position"], - ), - axis=-1, - ) - return trajectory - - -def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - DROID dataset transformation for actions expressed in *base* frame of the robot. 
- """ - dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] - dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] - trajectory["action"] = tf.concat( - ( - dt, - dr_, - 1 - trajectory["action_dict"]["gripper_position"], - ), - axis=-1, - ) - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["cartesian_position"], - trajectory["observation"]["gripper_position"], - ), - axis=-1, - ) - return trajectory - - -def zero_action_filter(traj: Dict) -> bool: - """ - Filters transitions whose actions are all-0 (only relative actions, no gripper action). - Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". - """ - droid_q01 = tf.convert_to_tensor( - [ - -0.7776297926902771, - -0.5803514122962952, - -0.5795090794563293, - -0.6464047729969025, - -0.7041108310222626, - -0.8895104378461838, - ] - ) - droid_q99 = tf.convert_to_tensor( - [ - 0.7597932070493698, - 0.5726242214441299, - 0.7351000607013702, - 0.6705610305070877, - 0.6464948207139969, - 0.8897542208433151, - ] - ) - droid_norm_0_act = ( - 2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1 - ) - - return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5) diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx/transforms.py b/lerobot/common/datasets/push_dataset_to_hub/openx/transforms.py deleted file mode 100644 index a0c1e30f..00000000 --- a/lerobot/common/datasets/push_dataset_to_hub/openx/transforms.py +++ /dev/null @@ -1,859 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -NOTE(YL): Adapted from: - OpenVLA: https://github.com/openvla/openvla - Octo: https://github.com/octo-models/octo - -transforms.py - -Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment. - -Transforms adopt the following structure: - Input: Dictionary of *batched* features (i.e., has leading time dimension) - Output: Dictionary `step` =>> { - "observation": { - - State (in chosen state representation) - }, - "action": Action (in chosen action representation), - "language_instruction": str - } -""" - -from typing import Any, Dict - -import tensorflow as tf - -from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import ( - binarize_gripper_actions, - invert_gripper_actions, - rel2abs_gripper_actions, - relabel_bridge_actions, -) - - -def droid_baseact_transform_fn(): - from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform - - return droid_baseact_transform - - -def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - Applies to version of Bridge V2 in Open X-Embodiment mixture. - - Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! 
- """ - for key in trajectory: - if key == "traj_metadata": - continue - elif key in ["observation", "action"]: - for key2 in trajectory[key]: - trajectory[key][key2] = trajectory[key][key2][1:] - else: - trajectory[key] = trajectory[key][1:] - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - trajectory = relabel_bridge_actions(trajectory) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - Applies to original version of Bridge V2 from the official project website. - - Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! - """ - for key in trajectory: - if key == "traj_metadata": - continue - elif key == "observation": - for key2 in trajectory[key]: - trajectory[key][key2] = trajectory[key][key2][1:] - else: - trajectory[key] = trajectory[key][1:] - - trajectory["action"] = tf.concat( - [ - trajectory["action"][:, :6], - binarize_gripper_actions(trajectory["action"][:, -1])[:, None], - ], - axis=1, - ) - trajectory = relabel_bridge_actions(trajectory) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - [ - trajectory["action"][:, :6], - binarize_gripper_actions(trajectory["action"][:, -1])[:, None], - ], - axis=1, - ) - trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] - return trajectory - - -def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - # decode compressed state - eef_value = tf.io.decode_compressed( - trajectory["observation"]["clip_function_input/base_pose_tool_reached"], - compression_type="ZLIB", - ) - eef_value = tf.io.decode_raw(eef_value, tf.float32) - trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) - gripper_value = tf.io.decode_compressed( - 
trajectory["observation"]["gripper_closed"], compression_type="ZLIB" - ) - gripper_value = tf.io.decode_raw(gripper_value, tf.float32) - trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6] - trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8] - trajectory["action"] = trajectory["action"]["rel_actions_world"] - - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - tf.clip_by_value(trajectory["action"][:, -1:], 0, 1), - ), - axis=-1, - ) - - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6] - trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][ - :, -1: - ] - - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - tf.zeros_like(trajectory["action"]["world_vector"]), - gripper_action[:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - tf.zeros_like(trajectory["action"]["world_vector"][:, :1]), - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert absolute gripper action, +1 = open, 0 = close - gripper_action = invert_gripper_actions( - tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1) - ) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action, - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - trajectory["language_embedding"] = trajectory["observation"]["natural_language_embedding"] - return trajectory - - -def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # make gripper action, +1 = open, 0 = close - 
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None] - gripper_action = tf.clip_by_value(gripper_action, 0, 1) - gripper_action = invert_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action, - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14] - - # make gripper action absolute action, +1 = open, 0 = close - gripper_action = trajectory["action"]["gripper_closedness_action"] - gripper_action = rel2abs_gripper_actions(gripper_action) - - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - gripper_action[:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # default to "open" gripper - trajectory["action"] = tf.concat( - ( - trajectory["action"], - tf.zeros_like(trajectory["action"]), - tf.zeros_like(trajectory["action"]), - tf.ones_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - - # decode language instruction - instruction_bytes = trajectory["observation"]["instruction"] - instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8") - # Remove trailing padding --> convert RaggedTensor to regular Tensor. 
- trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[ - :, 0 - ] - return trajectory - - -def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["world_vector"], - trajectory["action"]["rotation_delta"], - trajectory["action"]["gripper_closedness_action"][:, None], - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tf.zeros_like(trajectory["action"][:, :3]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:] - trajectory["action"] = trajectory["action"][..., :7] - return trajectory - - -def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(trajectory["action"][:, -1:]), - ), - axis=-1, - ) - - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :3], - trajectory["observation"]["state"][:, 7:10], - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2] - return trajectory - - -def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] - return trajectory - - -def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32) - trajectory["observation"]["depth_additional_view"] = tf.cast( - trajectory["observation"]["depth_additional_view"][..., 0], tf.float32 - ) - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:] - - # clip gripper action, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, -8:-2], - tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1), - ), - axis=-1, - ) - return trajectory - - -def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8] - return trajectory - - -def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - import tensorflow_graphics.geometry.transformation as tft - - trajectory["observation"]["state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :7], - trajectory["observation"]["state"][:, -1:], - ), - axis=-1, - ) - - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - 
tft.euler.from_quaternion(trajectory["action"][:, 3:7]), - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - return trajectory - - -def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tf.zeros_like(trajectory["action"][:, :3]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - return trajectory - - -def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - return trajectory - - -def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"]["future/xyz_residual"][:, :3], - trajectory["action"]["future/axis_angle_residual"][:, :3], - invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)), - ), - axis=-1, - ) - trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] - return trajectory - - -def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = trajectory["action"][..., -7:] - return trajectory - - -def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :4], - tf.zeros_like(trajectory["observation"]["state"][:, :2]), - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :4], - tf.zeros_like(trajectory["action"][:, :2]), - 
trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - """ - trajectory["observation"]["state"] = tf.concat(( - tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32), - trajectory["observation"]["pose"], - trajectory["observation"]["joint_pos"],), - axis=-1,) - """ - trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32) - return trajectory - - -def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32) - return trajectory - - -def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - tf.zeros_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - return trajectory - - -def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["end_effector_pose"][:, :4], - tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]), - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:] - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :4], - tf.zeros_like(trajectory["action"][:, :2]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6] - return trajectory - - -def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # invert gripper action, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(trajectory["action"][:, -1:]), - ), - axis=-1, - ) - return trajectory - - -def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - return trajectory - - -def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - import tensorflow_graphics.geometry.transformation as tft - - trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8] - trajectory["action"] = tf.concat( - ( - 
trajectory["action"][:, :3], - tft.euler.from_quaternion(trajectory["action"][:, 3:7]), - trajectory["action"][:, 7:8], - ), - axis=-1, - ) - return trajectory - - -def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"], - tf.zeros_like(trajectory["action"]), - tf.zeros_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - return trajectory - - -def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] - - # invert gripper action + clip, +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :6], - invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), - ), - axis=-1, - ) - return trajectory - - -def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6] - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7] - - # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close - trajectory["action"] = tf.concat( - ( - trajectory["action"], - invert_gripper_actions(trajectory["observation"]["gripper_state"]), - ), - axis=-1, - ) - return trajectory - - -def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - import tensorflow_graphics.geometry.transformation as tft - - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - tft.euler.from_quaternion(trajectory["action"][:, 3:7]), - trajectory["action"][:, -1:], - ), - axis=-1, - ) - return trajectory - - -def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :3], - trajectory["action"][:, -4:], - ), - axis=-1, - ) - return trajectory - - -def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["eef_state"] = tf.concat( - ( - trajectory["observation"]["state"][:, :3], - tf.zeros_like(trajectory["observation"]["state"][:, :3]), - ), - axis=-1, - ) - trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] - trajectory["action"] = trajectory["action"][..., :-1] - return trajectory - - -def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - trajectory["observation"]["state"] = tf.concat( - ( - trajectory["observation"]["position"], - tf.zeros_like(trajectory["observation"]["state"][:, :3]), - trajectory["observation"]["yaw"], - ), - axis=-1, - ) - trajectory["action"] = tf.concat( - ( - trajectory["action"], - tf.zeros_like(trajectory["action"]), - tf.zeros_like(trajectory["action"]), - tf.zeros_like(trajectory["action"][:, :1]), - ), - axis=-1, - ) - return trajectory - - -def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # every input feature is batched, ie has leading batch dimension - trajectory["observation"]["proprio"] = tf.concat( - ( - trajectory["observation"]["eef_pose"], - trajectory["observation"]["state_gripper_pose"][..., None], - ), - axis=-1, - ) - return trajectory - - -def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # every input feature is batched, ie has leading batch dimension - trajectory["observation"]["proprio"] = trajectory["observation"]["state"] - return trajectory - - -def 
robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - # gripper action is in -1...1 --> clip to 0...1, flip - gripper_action = trajectory["action"][:, -1:] - gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) - - trajectory["action"] = tf.concat( - ( - trajectory["action"][:, :7], - gripper_action, - ), - axis=-1, - ) - return trajectory - - -def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: - return trajectory - - -# === Registry === -OPENX_STANDARDIZATION_TRANSFORMS = { - "bridge_openx": bridge_openx_dataset_transform, - "bridge_orig": bridge_orig_dataset_transform, - "bridge_dataset": bridge_orig_dataset_transform, - "ppgm": ppgm_dataset_transform, - "ppgm_static": ppgm_dataset_transform, - "ppgm_wrist": ppgm_dataset_transform, - "fractal20220817_data": rt1_dataset_transform, - "kuka": kuka_dataset_transform, - "taco_play": taco_play_dataset_transform, - "jaco_play": jaco_play_dataset_transform, - "berkeley_cable_routing": berkeley_cable_routing_dataset_transform, - "roboturk": roboturk_dataset_transform, - "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform, - "viola": viola_dataset_transform, - "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform, - "toto": toto_dataset_transform, - "language_table": language_table_dataset_transform, - "columbia_cairlab_pusht_real": pusht_dataset_transform, - "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform, - "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform, - "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform, - "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform, - "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform, - "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform, - "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform, - "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform, - "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform, - "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform, - "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform, - "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform, - "bc_z": bc_z_dataset_transform, - "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform, - "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform, - "utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform, - "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform, - "robo_net": robo_net_dataset_transform, - "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform, - "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform, - "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform, - "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform, - "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform, - "dlr_sara_pour_converted_externally_to_rlds": identity_transform, - "dlr_sara_grid_clamp_converted_externally_to_rlds": 
dlr_sara_grid_clamp_dataset_transform, - "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform, - "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform, - "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform, - "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform, - "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform, - "uiuc_d3field": uiuc_d3field_dataset_transform, - "utaustin_mutex": utaustin_mutex_dataset_transform, - "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform, - "cmu_playing_with_food": cmu_playing_with_food_dataset_transform, - "cmu_play_fusion": playfusion_dataset_transform, - "cmu_stretch": cmu_stretch_dataset_transform, - "berkeley_gnm_recon": gnm_dataset_transform, - "berkeley_gnm_cory_hall": gnm_dataset_transform, - "berkeley_gnm_sac_son": gnm_dataset_transform, - "droid": droid_baseact_transform_fn(), - "droid_100": droid_baseact_transform_fn(), # first 100 episodes of droid - "fmb": fmb_transform, - "dobbe": dobbe_dataset_transform, - "robo_set": robo_set_dataset_transform, - "usc_cloth_sim_converted_externally_to_rlds": identity_transform, - "plex_robosuite": identity_transform, - "conq_hose_manipulation": identity_transform, - "io_ai_tech": identity_transform, - "spoc": identity_transform, -} diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py b/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py index cfe11503..1f8a5d14 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py @@ -14,13 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +For all datasets in the RLDS format, e.g. the Open X-Embodiment (OpenX) datasets from https://github.com/google-deepmind/open_x_embodiment. +NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
+ Example: python lerobot/scripts/push_dataset_to_hub.py \ - --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \ - --repo-id youliangtan/sampled_bridge_data_v2 \ - --raw-format openx_rlds.bridge_orig \ + --raw-dir /path/to/data/bridge_dataset/1.0.0/ \ + --repo-id your_hub/sampled_bridge_data_v2 \ + --raw-format rlds \ --episodes 3 4 5 8 9 Exact dataset fps defined in openx/config.py, obtained from: @@ -35,12 +38,10 @@ import tensorflow as tf import tensorflow_datasets as tfds import torch import tqdm -import yaml from datasets import Dataset, Features, Image, Sequence, Value from PIL import Image as PILImage from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION -from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS from lerobot.common.datasets.push_dataset_to_hub.utils import ( calculate_episode_data_index, concatenate_episodes, @@ -52,11 +53,6 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames -with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f: - _openx_list = yaml.safe_load(f) - -OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"] - np.set_printoptions(precision=2) @@ -108,7 +104,6 @@ def load_from_raw( video: bool, episodes: list[int] | None = None, encoding: dict | None = None, - openx_dataset_name: str | None = None, ): """ Args: @@ -136,16 +131,17 @@ def load_from_raw( # we will apply the standardization transform if the dataset_name is provided # if the dataset name is not provided and the goal is to convert any rlds formatted dataset # search for 'image' keys in the observations - if openx_dataset_name is not None: - print(" - applying standardization transform for dataset: ", openx_dataset_name) - assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS - transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name] - dataset = dataset.map(transform_fn) - - image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"] - else: - obs_keys = dataset_info.features["steps"]["observation"].keys() - image_keys = [key for key in obs_keys if "image" in key] + image_keys = [] + state_keys = [] + observation_info = dataset_info.features["steps"]["observation"] + for key in observation_info: + # check whether the key is for an image or a vector observation + if len(observation_info[key].shape) == 3: + # only adding uint8 images discards depth images + if observation_info[key].dtype == tf.uint8: + image_keys.append(key) + else: + state_keys.append(key) lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None @@ -193,50 +189,31 @@ def load_from_raw( num_frames = episode["action"].shape[0] - ########################################################### - # Handle the episodic data - - # last step of demonstration is considered done - done = torch.zeros(num_frames, dtype=torch.bool) - done[-1] = True ep_dict = {} - langs = [] # TODO: might be located in "observation" + for key in state_keys: + ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key]) - image_array_dict = {key: [] for key in image_keys} - - # We will create the state observation tensor by stacking the state - # obs keys defined in the openx/configs.py - if openx_dataset_name is not None: - state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"] - # stack the state observations, if is None, pad with zeros - states = [] - for key in state_obs_keys: - if 
key in episode["observation"]: - states.append(tf_to_torch(episode["observation"][key])) - else: - states.append(torch.zeros(num_frames, 1)) # pad with zeros - states = torch.cat(states, dim=1) - # assert states.shape == (num_frames, 8), f"states shape: {states.shape}" - else: - states = tf_to_torch(episode["observation"]["state"]) - - actions = tf_to_torch(episode["action"]) - rewards = tf_to_torch(episode["reward"]).float() + ep_dict["action"] = tf_to_torch(episode["action"]) + ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float() + ep_dict["next.done"] = tf_to_torch(episode["is_last"]) + ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"]) + ep_dict["is_first"] = tf_to_torch(episode["is_first"]) + ep_dict["discount"] = tf_to_torch(episode["discount"]) # If lang_key is present, convert the entire tensor at once if lang_key is not None: - langs = [str(x) for x in episode[lang_key]] + ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]] + + ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps + ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames) + ep_dict["frame_index"] = torch.arange(0, num_frames, 1) + + image_array_dict = {key: [] for key in image_keys} for im_key in image_keys: imgs = episode["observation"][im_key] image_array_dict[im_key] = [tf_img_convert(img) for img in imgs] - # simple assertions - for item in [states, actions, rewards, done]: - assert len(item) == num_frames - - ########################################################### - # loop through all cameras for im_key in image_keys: img_key = f"observation.images.{im_key}" @@ -262,17 +239,6 @@ def load_from_raw( else: ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array] - if lang_key is not None: - ep_dict["language_instruction"] = langs - - ep_dict["observation.state"] = states - ep_dict["action"] = actions - ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps - ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames) - ep_dict["frame_index"] = torch.arange(0, num_frames, 1) - ep_dict["next.reward"] = rewards - ep_dict["next.done"] = done - path_ep_dict = tmp_ep_dicts_dir.joinpath( "ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt" ) @@ -290,30 +256,28 @@ def load_from_raw( def to_hf_dataset(data_dict, video) -> Dataset: features = {} - keys = [key for key in data_dict if "observation.images." in key] - for key in keys: - if video: - features[key] = VideoFrame() - else: - features[key] = Image() + for key in data_dict: + # check if vector state obs + if key.startswith("observation.") and "observation.images." not in key: + features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)) + # check if image obs + elif "observation.images." 
in key: + if video: + features[key] = VideoFrame() + else: + features[key] = Image() - features["observation.state"] = Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) - ) - if "observation.velocity" in data_dict: - features["observation.velocity"] = Sequence( - length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None) - ) - if "observation.effort" in data_dict: - features["observation.effort"] = Sequence( - length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None) - ) if "language_instruction" in data_dict: features["language_instruction"] = Value(dtype="string", id=None) features["action"] = Sequence( length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None) ) + + features["is_terminal"] = Value(dtype="bool", id=None) + features["is_first"] = Value(dtype="bool", id=None) + features["discount"] = Value(dtype="float32", id=None) + features["episode_index"] = Value(dtype="int64", id=None) features["frame_index"] = Value(dtype="int64", id=None) features["timestamp"] = Value(dtype="float32", id=None) @@ -333,19 +297,8 @@ def from_raw_to_lerobot_format( video: bool = True, episodes: list[int] | None = None, encoding: dict | None = None, - openx_dataset_name: str | None = None, ): - """This is a test impl for rlds conversion""" - if openx_dataset_name is None: - # set a default rlds frame rate if the dataset is not from openx - fps = 30 - elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]: - raise ValueError( - "fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested" - ) - fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"] - - data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name) + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding) hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) info = { diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 2bb641a4..0233ede6 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -66,7 +66,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str): from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format elif raw_format == "aloha_hdf5": from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format - elif "openx_rlds" in raw_format: + elif raw_format in ["rlds", "openx"]: from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format elif raw_format == "dora_parquet": from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format @@ -204,24 +204,14 @@ def push_dataset_to_hub( # convert dataset from original raw format to LeRobot format from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) - fmt_kwgs = { - "raw_dir": raw_dir, - "videos_dir": videos_dir, - "fps": fps, - "video": video, - "episodes": episodes, - "encoding": encoding, - } - - if "openx_rlds." in raw_format: - # Support for official OXE dataset name inside `raw_format`. 
- # For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?), - # and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating - _, openx_dataset_name = raw_format.split(".") - print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.") - fmt_kwgs["openx_dataset_name"] = openx_dataset_name - - hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs) + hf_dataset, episode_data_index, info = from_raw_to_lerobot_format( + raw_dir, + videos_dir, + fps, + video, + episodes, + encoding, + ) lerobot_dataset = LeRobotDataset.from_preloaded( repo_id=repo_id, @@ -290,7 +280,7 @@ def main(): "--raw-format", type=str, required=True, - help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).", + help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).", ) parser.add_argument( "--repo-id",