initial commit, with UR env connected to sim backend
This commit is contained in:
154
examples/ur_sim/env.py
Normal file
154
examples/ur_sim/env.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import argparse
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
logging.getLogger('gymnasium').setLevel(logging.ERROR)
|
||||
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
from omni.isaac.lab.app import AppLauncher
|
||||
|
||||
# Command-line handling. AppLauncher registers its own flags on the parser.
parser = argparse.ArgumentParser(
    description="UR simulation environment for openpi policy evaluation."
)
AppLauncher.add_app_launcher_args(parser)

# Only the flags known to this parser are consumed here; everything else is
# put back on sys.argv so a later argument parser can read it (the original
# comment names hydra as that consumer).
args_cli, other_args = parser.parse_known_args()
sys.argv = [sys.argv[0]] + other_args

# Launch the Omniverse app. Both settings below are forced and therefore
# override whatever was passed on the command line.
args_cli.enable_cameras = True
args_cli.headless = False
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
|
||||
|
||||
"""Rest everything follows."""
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import torch
|
||||
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
import real2simeval.environments
|
||||
from real2simeval.splat_render.render import SplatRenderer
|
||||
from real2simeval.utils import get_transform_from_txt, scalar_last, decrease_brightness
|
||||
|
||||
from omni.isaac.lab_tasks.utils import parse_env_cfg
|
||||
from omni.isaac.core.prims import GeometryPrimView
|
||||
import omni.isaac.lab.utils.math as math
|
||||
|
||||
|
||||
# Asset root: the "data" directory found five directory levels above this file.
_THIS_FILE = Path(__file__)
DATA_PATH = _THIS_FILE.parent.parent.parent.parent.parent / "data"
|
||||
|
||||
class URSimEnvironment(_environment.Environment):
    """Runtime environment for a simulated UR5e arm with a Robotiq gripper.

    The simulation is created with ``gymnasium.make`` from an Isaac Lab
    configuration. Observations coming back from it are re-keyed to the
    ``observation/...`` names read by the policy and the video saver.
    """

    def __init__(self, task: str, seed: int = 0) -> None:
        """Create the simulated scene for ``task``.

        Args:
            task: Task id handed to ``parse_env_cfg`` and ``gymnasium.make``.
            seed: Seeds NumPy's global RNG and the generator that draws the
                seed passed to every ``reset``.
        """
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        # Recorded episode, opened relative to the current working directory.
        # NOTE(review): nothing in this class reads the handle and it is never
        # closed; it is kept only so the ``file`` attribute stays available.
        self.file = h5py.File("data/episode.h5", "r")
        # Number of actions applied since the last reset.
        self.step = 0

        env_cfg = parse_env_cfg(
            task,
            device=args_cli.device,
            num_envs=1,
            use_fabric=True,
        )

        # Asset directories handed to the task configuration.
        sim_assets = {
            "pi_scene_v2_static": DATA_PATH / "pi_scene_v2",
            "bottle": DATA_PATH / "pi_objects/bottle",
            "plate": DATA_PATH / "pi_objects/plate",
            "robot": DATA_PATH / "pi_robot/",
        }
        env_cfg.setup_scene(sim_assets)

        self._gym = gymnasium.make(task, cfg=env_cfg)

        self._last_obs = None
        # Reported as done until reset() starts the first episode.
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        """Start a new episode and cache its first observation."""
        episode_seed = int(self._rng.integers(2**32 - 1))
        gym_obs, _ = self._gym.reset(seed=episode_seed)

        # The step counter is per episode, so it restarts with the simulation.
        self.step = 0
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def done(self) -> bool:
        """Return whether the current episode has terminated or been truncated."""
        return self._done

    @override
    def get_observation(self) -> dict:
        """Return the most recent converted observation.

        Raises:
            RuntimeError: If no observation exists yet because ``reset`` has
                not been called.
        """
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        """Step the simulation with ``action["actions"]`` and show a preview.

        Args:
            action: Dict whose ``"actions"`` entry is an array; its last
                element is the gripper command.
        """
        # Work on a copy so the caller's array is not modified.
        command = action["actions"].copy()
        # scale gripper from [0,1] to [-1,1]
        command[-1] = command[-1] * 2 - 1

        # Leading axis added with [None] because the env is built with num_envs=1.
        batched = torch.tensor(command, dtype=torch.float32)[None]
        gym_obs, _reward, terminated, truncated, _info = self._gym.step(batched)

        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        # Normalise to a plain bool so done() matches its annotation even when
        # the env reports the flags as single-element arrays or tensors.
        self._done = bool(terminated) or bool(truncated)

        # Live preview: base and wrist camera images side by side.
        base_img = self._last_obs["observation/base_0_camera/rgb/image"]
        wrist_img = self._last_obs["observation/wrist_0_camera/rgb/image"]
        big_img = np.concatenate([base_img, wrist_img], axis=1)
        cv2.imshow("big_img", cv2.cvtColor(big_img, cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)
        self.step += 1

    def _convert_observation(self, gym_obs: dict) -> dict:
        """Re-key a raw simulator observation to the policy's observation names.

        Args:
            gym_obs: Observation with ``"policy"`` and ``"splat"`` groups.

        Returns:
            Dict with arm joints, gripper position and the two camera images.
        """
        joints = gym_obs["policy"]["joints"]
        cameras = gym_obs["splat"]

        data = {}
        # First six entries are reported as arm joints, the rest as gripper.
        data["observation/ur5e/joints/position"] = joints[:6].detach().cpu().numpy()
        data["observation/robotiq_gripper/gripper/position"] = joints[6:].detach().cpu().numpy()
        data["observation/base_0_camera/rgb/image"] = cameras["base_cam"]
        data["observation/wrist_0_camera/rgb/image"] = cameras["wrist_cam"]

        return data
|
||||
|
||||
228
examples/ur_sim/env.py.back
Normal file
228
examples/ur_sim/env.py.back
Normal file
@@ -0,0 +1,228 @@
|
||||
import argparse
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
logging.getLogger('gymnasium').setLevel(logging.ERROR)
|
||||
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
from omni.isaac.lab.app import AppLauncher
|
||||
|
||||
# Build the CLI parser and let AppLauncher add its own options to it.
parser = argparse.ArgumentParser(
    description="UR simulation environment for openpi policy evaluation."
)
AppLauncher.add_app_launcher_args(parser)

# Unrecognised arguments are written back to sys.argv so that a later parser
# can pick them up (the original comment names hydra as that consumer).
args_cli, other_args = parser.parse_known_args()
sys.argv = [sys.argv[0]] + other_args

# Launch the Omniverse app with cameras enabled and no GUI; both values are
# forced here and override the command line.
args_cli.enable_cameras = True
args_cli.headless = True
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
|
||||
|
||||
"""Rest everything follows."""
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import torch
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
import real2simeval.environments
|
||||
from real2simeval.splat_render.render import SplatRenderer
|
||||
from real2simeval.utils import get_transform_from_txt, scalar_last, decrease_brightness
|
||||
|
||||
from omni.isaac.lab_tasks.utils import parse_env_cfg
|
||||
from omni.isaac.core.prims import GeometryPrimView
|
||||
import omni.isaac.lab.utils.math as math
|
||||
|
||||
|
||||
class URSimEnvironment(_environment.Environment):
    """Runtime environment for a simulated UR5e arm, rendered with splats.

    The simulation is created with ``gymnasium.make``. Camera images are not
    taken from the simulator; they are produced by a ``SplatRenderer`` whose
    objects are moved to the poses read back from the simulated scene.
    """

    def __init__(self, task: str, seed: int = 0) -> None:
        """Create the simulation and the splat renderer for ``task``.

        Args:
            task: Task id handed to ``parse_env_cfg`` and ``gymnasium.make``.
            seed: Seeds NumPy's global RNG and the generator that draws the
                seed passed to every ``reset``.
        """
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        # Recorded episode, opened relative to the current working directory.
        # NOTE(review): nothing in this class reads the handle and it is never
        # closed; it is kept only so the ``file`` attribute stays available.
        self.file = h5py.File("data/episode.h5", "r")
        # Number of actions applied since the last reset.
        self.step = 0

        env_cfg = parse_env_cfg(
            task,
            device=args_cli.device,
            num_envs=1,
            use_fabric=True,
        )
        self._gym = gymnasium.make(task, cfg=env_cfg)

        # Splat files keyed by object name (paths relative to the cwd).
        splats = {
            "pi_scene_v2": "./data/pi_scene_v2/splat.ply",
            "bottle": "./data/pi_objects/bottle/splat.ply",
            "plate": "./data/pi_objects/plate/splat.ply",
        }
        # One prim view per robot splat file. The file stem, with "-" turned
        # into "/", is used as the prim path below the robot.
        views = {}
        robot = Path("./data/pi_robot/SEGMENTED/")
        for ply in robot.glob("*.ply"):
            splats[ply.stem] = str(ply)
            path = ply.stem.replace("-", "/")
            views[ply.stem] = GeometryPrimView(
                prim_paths_expr=f"/World/envs/env_.*/robot/{path}",
            )

        splat_renderer = SplatRenderer(splats=splats)
        splat_renderer.init_cameras({
            "hand_cam": {"fovy": 1.04, "fovx": 1.33, "res": (480, 640)},
            "third_person_cam": {"fovy": 1.04, "fovx": 1.33, "res": (480, 640)},
        })

        self.splats = splats
        self.views = views
        self.splat_renderer = splat_renderer
        # Whether the "pi_scene_v2" splat has been transformed this episode.
        # Set here as well as in reset() so the attribute always exists.
        self.env_transformed = False

        self._last_obs = None
        # Reported as done until reset() starts the first episode.
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        """Start a new episode and cache its first observation."""
        episode_seed = int(self._rng.integers(2**32 - 1))
        gym_obs, _ = self._gym.reset(seed=episode_seed)

        # The step counter is per episode, so it restarts with the simulation.
        self.step = 0
        self.env_transformed = False
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def done(self) -> bool:
        """Return whether the current episode has terminated or been truncated."""
        return self._done

    @override
    def get_observation(self) -> dict:
        """Return the most recent converted observation.

        Raises:
            RuntimeError: If no observation exists yet because ``reset`` has
                not been called.
        """
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        """Step the simulation with ``action["actions"]`` and show a preview.

        Args:
            action: Dict whose ``"actions"`` entry is an array; its last
                element is the gripper command.
        """
        # Work on a copy so the caller's array is not modified.
        command = action["actions"].copy()
        # scale gripper from [0,1] to [-1,1]
        command[-1] = command[-1] * 2 - 1

        # Leading axis added with [None] because the env is built with num_envs=1.
        batched = torch.tensor(command, dtype=torch.float32)[None]
        gym_obs, _reward, terminated, truncated, _info = self._gym.step(batched)

        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        # Normalise to a plain bool so done() matches its annotation even when
        # the env reports the flags as single-element arrays or tensors.
        self._done = bool(terminated) or bool(truncated)

        # Live preview: base and wrist camera images side by side.
        base_img = self._last_obs["observation/base_0_camera/rgb/image"]
        wrist_img = self._last_obs["observation/wrist_0_camera/rgb/image"]
        big_img = np.concatenate([base_img, wrist_img], axis=1)
        cv2.imshow("big_img", cv2.cvtColor(big_img, cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)
        self.step += 1

    def _convert_observation(self, gym_obs: dict) -> dict:
        """Render both cameras and re-key the observation for the policy.

        Args:
            gym_obs: Raw observation with a ``"policy"`` group holding joints.

        Returns:
            Dict with arm joints, gripper position and the two rendered images.
        """
        self._sync_splat_poses()

        cam_extrinsics_dict = {
            "hand_cam": self._camera_pose("handcam"),
            "third_person_cam": self._camera_pose("camera"),
        }
        rendered = self.splat_renderer.render(cam_extrinsics_dict)
        # Rendered values are scaled by 255 and stored as uint8 images.
        rgb = {
            name: (image.detach().cpu().numpy() * 255).astype(np.uint8)
            for name, image in rendered.items()
        }

        joints = gym_obs["policy"]["joints"]
        data = {}
        # First six entries are reported as arm joints, the rest as gripper.
        data["observation/ur5e/joints/position"] = joints[:6].detach().cpu().numpy()
        data["observation/robotiq_gripper/gripper/position"] = joints[6:].detach().cpu().numpy()
        data["observation/base_0_camera/rgb/image"] = rgb["third_person_cam"]
        data["observation/wrist_0_camera/rgb/image"] = rgb["hand_cam"]

        return data

    def _sync_splat_poses(self) -> None:
        """Move every splat to the current pose of its simulated counterpart."""
        for name in self.splats:
            if name == "pi_scene_v2":
                # The scene splat is transformed at most once per episode.
                if self.env_transformed:
                    continue
                self.env_transformed = True

            if name in self.views:
                # Robot parts: pose read from their prim view.
                pos, quat = self.views[name].get_world_poses()
                pos, quat = pos.squeeze(), quat.squeeze()
            else:
                # Other objects: looked up by name in the scene; splats with
                # no scene entry of that name are skipped.
                try:
                    body = self._gym.scene[name]
                except KeyError:
                    continue
                pos = body.data.root_state_w[0, :3]
                quat = body.data.root_state_w[0, 3:7]

            # ``math`` is omni.isaac.lab.utils.math here, not the stdlib module.
            rot = math.matrix_from_quat(quat)
            self.splat_renderer.transform(
                pos,
                rot,
                scale_factor=1.0,
                obj=name,
            )

    def _camera_pose(self, sensor_name: str) -> dict:
        """Return position and rotation matrix of the scene sensor ``sensor_name``."""
        sensor = self._gym.scene[sensor_name]
        pos = sensor.data.pos_w[0].detach().cpu().numpy()
        quat = sensor.data.quat_w_world[0].detach().cpu().numpy()
        # NOTE(review): scalar_last presumably reorders the quaternion to
        # (x, y, z, w), the order SciPy's from_quat expects — confirm.
        rot = R.from_quat(scalar_last(quat)).as_matrix()
        return {"pos": pos, "rot": rot}
|
||||
|
||||
56
examples/ur_sim/main.py
Normal file
56
examples/ur_sim/main.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import env as _env
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import saver as _saver
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Args:
    """Command-line options for the UR simulation example."""

    # Where the recorded video is written.
    out_path: pathlib.Path = pathlib.Path("replay.mp4")

    # Task id and seed forwarded to the simulated environment.
    task: str = "PIBussing"
    seed: int = 0

    # Forwarded to ActionChunkBroker as ``action_horizon``.
    action_horizon: int = 10

    # Address of the policy server the websocket client connects to.
    host: str = "0.0.0.0"
    port: int = 8000

    # Not read by ``main``.
    display: bool = False
|
||||
|
||||
|
||||
def main(args: Args) -> None:
    """Wire the simulated environment to a remote policy and run the runtime.

    Args:
        args: Parsed command-line options.
    """
    # Built in the same order as the original nested call: environment first,
    # then the policy stack, then the subscribers.
    environment = _env.URSimEnvironment(
        task=args.task,
        seed=args.seed,
    )

    remote_policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    chunked_policy = action_chunk_broker.ActionChunkBroker(
        policy=remote_policy,
        action_horizon=args.action_horizon,
    )
    agent = _policy_agent.PolicyAgent(policy=chunked_policy)

    video_saver = _saver.VideoSaver(args.out_path)

    runtime = _runtime.Runtime(
        environment=environment,
        agent=agent,
        subscribers=[video_saver],
        max_hz=50,
    )
    runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(force=True, level=logging.INFO)
    # tyro builds the command-line interface from main's ``Args`` parameter.
    tyro.cli(main)
|
||||
40
examples/ur_sim/saver.py
Normal file
40
examples/ur_sim/saver.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
# import openpi.transforms as transforms
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoSaver(_subscriber.Subscriber):
    """Records side-by-side camera frames of an episode and saves them as a video."""

    def __init__(self, out_path: pathlib.Path, subsample: int = 1) -> None:
        """Store the output path and the frame stride.

        Args:
            out_path: File the video is written to at the end of an episode.
            subsample: Keep every ``subsample``-th frame when writing; values
                below 1 are treated as 1.
        """
        self._out_path = out_path
        self._images: list[np.ndarray] = []
        self._subsample = subsample

    @override
    def on_episode_start(self) -> None:
        """Discard frames left over from a previous episode."""
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        """Append the base and wrist camera images, joined side by side.

        Args:
            observation: Must hold both ``observation/..._camera/rgb/image`` keys.
            action: Unused.
        """
        base_img = observation["observation/base_0_camera/rgb/image"]
        wrist_img = observation["observation/wrist_0_camera/rgb/image"]
        # Joined along axis 1, so the other dimensions of both images must match.
        self._images.append(np.concatenate([base_img, wrist_img], axis=1))

    @override
    def on_episode_end(self) -> None:
        """Write the collected frames to ``out_path``; skip if there are none."""
        if not self._images:
            logging.warning("No frames recorded; not writing %s", self._out_path)
            return

        # Use one clamped stride for both the frame selection and the frame
        # rate, so a subsample of 0 cannot raise on the slice and a large one
        # cannot produce an fps of 0.
        stride = max(1, self._subsample)
        frames = [np.asarray(frame) for frame in self._images[::stride]]

        logging.info("Saving video to %s", self._out_path)
        imageio.mimwrite(
            self._out_path,
            frames,
            fps=max(1, 20 // stride),
        )
|
||||
@@ -11,7 +11,7 @@ from openpi.models import exported as _exported
|
||||
from openpi.models import model as _model
|
||||
from openpi.policies import aloha_policy
|
||||
from openpi.policies import calvin_policy
|
||||
from openpi.policies import droid_policy
|
||||
from openpi.policies import droid_policy, ur_policy
|
||||
from openpi.policies import libero_policy
|
||||
from openpi.policies import policy as _policy
|
||||
from openpi.policies import policy_config as _policy_config
|
||||
@@ -28,6 +28,7 @@ class EnvMode(enum.Enum):
|
||||
DROID = "droid"
|
||||
CALVIN = "calvin"
|
||||
LIBERO = "libero"
|
||||
UR = "ur"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -109,6 +110,10 @@ DEFAULT_EXPORTED: dict[EnvMode, Exported] = {
|
||||
dir="s3://openpi-assets/exported/pi0_libero/model",
|
||||
processor="libero",
|
||||
),
|
||||
EnvMode.UR: Exported(
|
||||
dir="s3://openpi-assets/exported/pi0_base/model",
|
||||
processor="ur5_single_24dim"
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -222,9 +227,22 @@ def create_default_policy(
|
||||
libero_policy.LiberoOutputs(),
|
||||
],
|
||||
)
|
||||
case EnvMode.UR:
|
||||
delta_action_mask = delta_actions.make_bool_mask(6, -1)
|
||||
|
||||
config = make_policy_config(
|
||||
input_layers=[
|
||||
ur_policy.URInputs(action_dim=model.action_dim),
|
||||
transforms.ResizeImages(224,224),
|
||||
],
|
||||
output_layers=[
|
||||
ur_policy.UROutputs(
|
||||
delta_action_mask=delta_action_mask,
|
||||
)
|
||||
],
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unknown environment mode: {env}")
|
||||
|
||||
return _policy_config.create_policy(config)
|
||||
|
||||
|
||||
|
||||
58
src/openpi/policies/ur_policy.py
Normal file
58
src/openpi/policies/ur_policy.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
|
||||
|
||||
class URInputs(transforms.DataTransformFn):
    """Converts a UR observation dict into the model's input dict."""

    def __init__(self, action_dim: int, *, delta_action_mask: Sequence[bool] | None = None):
        """Store the transform settings.

        Args:
            action_dim: Dimension the state vector is padded to.
            delta_action_mask: Stored on the instance; ``__call__`` does not read it.
        """
        self._action_dim = action_dim
        self._delta_action_mask = delta_action_mask

    def __call__(self, data: dict) -> dict:
        """Build the ``state`` / ``image`` / ``image_mask`` inputs from ``data``.

        Args:
            data: Observation with the UR5e joint, gripper and camera keys,
                and optionally ``"prompt"``.

        Returns:
            Model input dict; ``"prompt"`` is copied over when present.
        """
        # Arm joints followed by gripper position. The join is along axis 1,
        # so both arrays need at least two dimensions.
        # NOTE(review): presumably (batch, n) arrays — confirm with the caller.
        state = np.concatenate(
            [
                data["observation/ur5e/joints/position"],
                data["observation/robotiq_gripper/gripper/position"],
            ],
            axis=1,
        )
        state = transforms.pad_to_dim(state, self._action_dim)

        base_image = data["observation/base_0_camera/rgb/image"]
        wrist_image = data["observation/wrist_0_camera/rgb/image"]

        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # There is only one wrist camera: the right slot gets an
                # all-zero image and is masked out below.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.ones(1, dtype=np.bool_),
                "left_wrist_0_rgb": np.ones(1, dtype=np.bool_),
                "right_wrist_0_rgb": np.zeros(1, dtype=np.bool_),
            },
        }

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
|
||||
|
||||
|
||||
class UROutputs(transforms.DataTransformFn):
    """Trims model actions to seven dimensions and optionally adds the state."""

    def __init__(self, *, delta_action_mask: Sequence[bool] | None = None):
        """Store the optional mask selecting dimensions the state is added to."""
        self._delta_action_mask = delta_action_mask

    def __call__(self, data: dict) -> dict:
        """Return ``{"actions": ...}`` holding the first seven action dimensions.

        Args:
            data: Must hold ``"actions"``; ``"state"`` is read only when a
                mask was given.
        """
        # Keep only the first 7 dims of the last axis.
        trimmed = np.asarray(data["actions"][..., :7])

        if self._delta_action_mask is None:
            return {"actions": trimmed}

        # Where the mask is set, add the state to the action (zero elsewhere).
        # The offset gets a new second-to-last axis so it broadcasts over that
        # axis of the actions.
        # NOTE(review): presumably this turns delta actions into absolute
        # targets — confirm.
        current = np.asarray(data["state"][..., :7])
        selected = np.asarray(self._delta_action_mask[:7])
        offset = np.expand_dims(np.where(selected, current, 0), axis=-2)

        return {"actions": trimmed + offset}
|
||||
Reference in New Issue
Block a user