diff --git a/examples/ur_sim/env.py b/examples/ur_sim/env.py
new file mode 100644
index 0000000..ffd36c3
--- /dev/null
+++ b/examples/ur_sim/env.py
@@ -0,0 +1,154 @@
+import argparse
+import time
+import sys
+import logging
+
+logging.getLogger('gymnasium').setLevel(logging.ERROR)
+
+import warnings
+
+warnings.filterwarnings('ignore', category=UserWarning)
+
+from omni.isaac.lab.app import AppLauncher
+
+# add argparse arguments
+parser = argparse.ArgumentParser(description="UR simulation environment for openpi policy evaluation.")
+# append AppLauncher cli args
+AppLauncher.add_app_launcher_args(parser)
+# parse the arguments
+args_cli, other_args = parser.parse_known_args()
+sys.argv = [sys.argv[0]] + other_args  # clear out sys.argv for hydra
+
+# launch omniverse app
+args_cli.enable_cameras = True
+args_cli.headless = False
+app_launcher = AppLauncher(args_cli)
+simulation_app = app_launcher.app
+
+"""Rest everything follows."""
+
+import cv2
+import h5py
+import torch
+import gymnasium
+import numpy as np
+from pathlib import Path
+from openpi_client.runtime import environment as _environment
+from typing_extensions import override
+from scipy.spatial.transform import Rotation as R
+import real2simeval.environments
+from real2simeval.splat_render.render import SplatRenderer
+from real2simeval.utils import get_transform_from_txt, scalar_last, decrease_brightness
+
+from omni.isaac.lab_tasks.utils import parse_env_cfg
+from omni.isaac.core.prims import GeometryPrimView
+import omni.isaac.lab.utils.math as math
+
+DATA_PATH = Path(__file__).parent.parent.parent.parent.parent / "data"
+
+
+class URSimEnvironment(_environment.Environment):
+    """An environment for a UR5e robot in simulation."""
+
+    def __init__(self, task: str, seed: int = 0) -> None:
+        np.random.seed(seed)
+        self._rng = np.random.default_rng(seed)
+
+        self.file = h5py.File("data/episode.h5", "r")
+        self.step = 0
+
+        env_cfg = parse_env_cfg(
+            task,
+            device=args_cli.device,
+            num_envs=1,
+            use_fabric=True,
+        )
+
+        sim_assets = {
+            "pi_scene_v2_static": DATA_PATH / "pi_scene_v2",
+            "bottle": DATA_PATH / "pi_objects/bottle",
+            "plate": DATA_PATH / "pi_objects/plate",
+            "robot": DATA_PATH / "pi_robot/",
+        }
+        env_cfg.setup_scene(sim_assets)
+
+        self._gym = gymnasium.make(task, cfg=env_cfg)
+
+        self._last_obs = None
+        self._done = True
+        self._episode_reward = 0.0
+
+    @override
+    def reset(self) -> None:
+        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
+
+        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
+        self._done = False
+        self._episode_reward = 0.0
+
+    @override
+    def done(self) -> bool:
+        return self._done
+
+    @override
+    def get_observation(self) -> dict:
+        if self._last_obs is None:
+            raise RuntimeError("Observation is not set. Call reset() first.")
+
+        return self._last_obs  # type: ignore
+
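+    # `action["actions"]` is a single 7-dim step: six UR5e joint positions
+    # followed by a gripper command in [0, 1]; apply_action rescales the
+    # gripper into the simulator's [-1, 1] range.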
+    @override
+    def apply_action(self, action: dict) -> None:
+        action = action["actions"]
+
+        # Replay-from-dataset alternative for debugging:
+        # ur5e = self.file["observation/ur5e/joints/position"][self.step]
+        # robotiq = self.file["observation/robotiq_gripper/gripper/position"][self.step]
+        # action = np.concatenate([ur5e, robotiq], axis=-1)
+
+        # scale gripper from [0, 1] to [-1, 1]
+        action = action.copy()
+        action[-1] = action[-1] * 2 - 1
+
+        action = torch.tensor(action, dtype=torch.float32)[None]
+        gym_obs, reward, terminated, truncated, info = self._gym.step(action)
+
+        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
+        self._done = terminated or truncated
+        # self._episode_reward = max(self._episode_reward, reward)
+
+        # Show both camera views side by side for monitoring.
+        img1 = self._last_obs["observation/base_0_camera/rgb/image"]
+        img2 = self._last_obs["observation/wrist_0_camera/rgb/image"]
+        big_img = np.concatenate([img1, img2], axis=1)
+        cv2.imshow("big_img", cv2.cvtColor(big_img, cv2.COLOR_RGB2BGR))
+        cv2.waitKey(1)
+        self.step += 1
+
+    def _convert_observation(self, gym_obs: dict) -> dict:
+        data = {}
+        data["observation/ur5e/joints/position"] = gym_obs["policy"]["joints"][:6].detach().cpu().numpy()
+        data["observation/robotiq_gripper/gripper/position"] = gym_obs["policy"]["joints"][6:].detach().cpu().numpy()
+        data["observation/base_0_camera/rgb/image"] = gym_obs["splat"]["base_cam"]
+        data["observation/wrist_0_camera/rgb/image"] = gym_obs["splat"]["wrist_cam"]
+
+        # Replay-from-dataset alternatives for debugging:
+        # data["observation/base_0_camera/rgb/image"] = self.file["observation/base_0_camera/rgb/image_224_224"][self.step]
+        # data["observation/wrist_0_camera/rgb/image"] = self.file["observation/wrist_0_camera/rgb/image_224_224"][self.step]
+        # data["observation/base_0_camera/rgb/image"] = self.file["observation/base_0_camera/rgb/image_256_320"][self.step]
+        # data["observation/wrist_0_camera/rgb/image"] = self.file["observation/wrist_0_camera/rgb/image_256_320"][self.step]
+        # data["observation/ur5e/joints/position"] = self.file["observation/ur5e/joints/position"][self.step]
+        # data["observation/robotiq_gripper/gripper/position"] = self.file["observation/robotiq_gripper/gripper/position"][self.step]
+
+        return data
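A note on the action contract above: apply_action expects a 7-dim step (six
UR5e joint positions plus one gripper value in [0, 1]) and rescales only the
gripper into the simulator's [-1, 1] range. A minimal standalone sketch of
that conversion, mirroring the lines in apply_action:

    import numpy as np

    def to_sim_action(policy_action: np.ndarray) -> np.ndarray:
        # Rescale the trailing gripper dim from [0, 1] to [-1, 1];
        # the six arm dims pass through unchanged.
        sim_action = policy_action.copy()
        sim_action[..., -1] = sim_action[..., -1] * 2 - 1
        return sim_action

    # A closed gripper (0.0) maps to -1.0, an open one (1.0) to +1.0.
    assert to_sim_action(np.zeros(7))[-1] == -1.0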
diff --git a/examples/ur_sim/env.py.back b/examples/ur_sim/env.py.back
new file mode 100644
index 0000000..9ce3902
--- /dev/null
+++ b/examples/ur_sim/env.py.back
@@ -0,0 +1,228 @@
+import argparse
+import time
+import sys
+import logging
+
+logging.getLogger('gymnasium').setLevel(logging.ERROR)
+
+import warnings
+
+warnings.filterwarnings('ignore', category=UserWarning)
+
+from omni.isaac.lab.app import AppLauncher
+
+# add argparse arguments
+parser = argparse.ArgumentParser(description="UR simulation environment for openpi policy evaluation.")
+# append AppLauncher cli args
+AppLauncher.add_app_launcher_args(parser)
+# parse the arguments
+args_cli, other_args = parser.parse_known_args()
+sys.argv = [sys.argv[0]] + other_args  # clear out sys.argv for hydra
+
+# launch omniverse app
+args_cli.enable_cameras = True
+args_cli.headless = True
+app_launcher = AppLauncher(args_cli)
+simulation_app = app_launcher.app
+
+"""Rest everything follows."""
+
+import cv2
+import h5py
+import torch
+import gymnasium
+import numpy as np
+from pathlib import Path
+from openpi_client.runtime import environment as _environment
+from typing_extensions import override
+from scipy.spatial.transform import Rotation as R
+import real2simeval.environments
+from real2simeval.splat_render.render import SplatRenderer
+from real2simeval.utils import get_transform_from_txt, scalar_last, decrease_brightness
+
+from omni.isaac.lab_tasks.utils import parse_env_cfg
+from omni.isaac.core.prims import GeometryPrimView
+import omni.isaac.lab.utils.math as math
+
+
+class URSimEnvironment(_environment.Environment):
+    """An environment for a UR5e robot in simulation."""
+
+    def __init__(self, task: str, seed: int = 0) -> None:
+        np.random.seed(seed)
+        self._rng = np.random.default_rng(seed)
+
+        self.file = h5py.File("data/episode.h5", "r")
+        self.step = 0
+
+        env_cfg = parse_env_cfg(
+            task,
+            device=args_cli.device,
+            num_envs=1,
+            use_fabric=True,
+        )
+        self._gym = gymnasium.make(task, cfg=env_cfg)
+
+        # Gaussian-splat assets: static scene, manipulable objects, and the
+        # segmented robot links (one .ply per link).
+        splats = {
+            "pi_scene_v2": "./data/pi_scene_v2/splat.ply",
+            "bottle": "./data/pi_objects/bottle/splat.ply",
+            "plate": "./data/pi_objects/plate/splat.ply",
+        }
+        views = {}
+        robot = Path("./data/pi_robot/SEGMENTED/")
+        for ply in robot.glob("*.ply"):
+            splats[ply.stem] = str(ply)
+            path = ply.stem.replace("-", "/")
+            view = GeometryPrimView(
+                prim_paths_expr=f"/World/envs/env_.*/robot/{path}",
+            )
+            views[ply.stem] = view
+
+        splat_renderer = SplatRenderer(splats=splats)
+
+        splat_renderer.init_cameras({
+            "hand_cam": {"fovy": 1.04, "fovx": 1.33, "res": (480, 640)},
+            "third_person_cam": {"fovy": 1.04, "fovx": 1.33, "res": (480, 640)},
+            # "hand_cam": {"fovy": 0.7925, "fovx": 1.01, "res": (480, 640)},
+            # "third_person_cam": {"fovy": 0.7925, "fovx": 1.01, "res": (480, 640)},
+        })
+
+        self.splats = splats
+        self.views = views
+        self.splat_renderer = splat_renderer
+
+        self._last_obs = None
+        self._done = True
+        self._episode_reward = 0.0
+
+    @override
+    def reset(self) -> None:
+        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
+
+        self.env_transformed = False
+        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
+        self._done = False
+        self._episode_reward = 0.0
+
+    @override
+    def done(self) -> bool:
+        return self._done
+
+    @override
+    def get_observation(self) -> dict:
+        if self._last_obs is None:
+            raise RuntimeError("Observation is not set. Call reset() first.")
+
+        return self._last_obs  # type: ignore
+
+    @override
+    def apply_action(self, action: dict) -> None:
+        action = action["actions"]
+
+        # Replay-from-dataset alternative for debugging:
+        # ur5e = self.file["observation/ur5e/joints/position"][self.step]
+        # robotiq = self.file["observation/robotiq_gripper/gripper/position"][self.step]
+        # action = np.concatenate([ur5e, robotiq], axis=-1)
+
+        # scale gripper from [0, 1] to [-1, 1]
+        action = action.copy()
+        action[-1] = action[-1] * 2 - 1
+
+        action = torch.tensor(action, dtype=torch.float32)[None]
+        gym_obs, reward, terminated, truncated, info = self._gym.step(action)
+
+        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
+        self._done = terminated or truncated
+        # self._episode_reward = max(self._episode_reward, reward)
+
+        img1 = self._last_obs["observation/base_0_camera/rgb/image"]
+        img2 = self._last_obs["observation/wrist_0_camera/rgb/image"]
+        big_img = np.concatenate([img1, img2], axis=1)
+        cv2.imshow("big_img", cv2.cvtColor(big_img, cv2.COLOR_RGB2BGR))
+        cv2.waitKey(1)
+        self.step += 1
+
+    def _convert_observation(self, gym_obs: dict) -> dict:
+        # Sync every splat with its pose in the simulator. Robot links come
+        # from their GeometryPrimView; free objects from the scene root state.
+        for splat in self.splats:
+            if splat == "pi_scene_v2":
+                # The static scene only needs to be placed once per episode.
+                if self.env_transformed:
+                    continue
+                else:
+                    self.env_transformed = True
+            if splat in self.views:
+                view = self.views[splat]
+                pos, rot = view.get_world_poses()
+                pos, rot = pos.squeeze(), rot.squeeze()
+            else:
+                try:
+                    body = self._gym.scene[splat]
+                except KeyError:
+                    continue
+
+                pos = body.data.root_state_w[0, :3]
+                rot = body.data.root_state_w[0, 3:7]
+
+            rot = math.matrix_from_quat(rot)
+            self.splat_renderer.transform(
+                pos,
+                rot,
+                scale_factor=1.0,
+                obj=splat,
+            )
+
+        # Camera extrinsics: convert Isaac's scalar-first quaternions to
+        # rotation matrices via SciPy (which expects scalar-last).
+        cam_pos_hand = self._gym.scene["handcam"].data.pos_w[0].detach().cpu().numpy()
+        cam_rot_hand = self._gym.scene["handcam"].data.quat_w_world[0].detach().cpu().numpy()
+        cam_rot_hand = scalar_last(cam_rot_hand)
+        cam_rot_hand = R.from_quat(cam_rot_hand).as_matrix()
+
+        cam_pos = self._gym.scene["camera"].data.pos_w[0].detach().cpu().numpy()
+        cam_rot = self._gym.scene["camera"].data.quat_w_world[0].detach().cpu().numpy()
+        cam_rot = scalar_last(cam_rot)
+        cam_rot = R.from_quat(cam_rot).as_matrix()
+
+        cam_extrinsics_dict = {
+            "hand_cam": {
+                "pos": cam_pos_hand,
+                "rot": cam_rot_hand,
+            },
+            "third_person_cam": {
+                "pos": cam_pos,
+                "rot": cam_rot,
+            },
+        }
+        rgb = self.splat_renderer.render(cam_extrinsics_dict)
+        for k, v in rgb.items():
+            rgb[k] = v.detach().cpu().numpy()
+            rgb[k] = (rgb[k] * 255).astype(np.uint8)
+
+        data = {}
+        data["observation/ur5e/joints/position"] = gym_obs["policy"]["joints"][:6].detach().cpu().numpy()
+        data["observation/robotiq_gripper/gripper/position"] = gym_obs["policy"]["joints"][6:].detach().cpu().numpy()
+        data["observation/base_0_camera/rgb/image"] = rgb["third_person_cam"]
+        data["observation/wrist_0_camera/rgb/image"] = rgb["hand_cam"]
+
+        # Replay-from-dataset alternatives for debugging:
+        # data["observation/base_0_camera/rgb/image"] = self.file["observation/base_0_camera/rgb/image_224_224"][self.step]
+        # data["observation/wrist_0_camera/rgb/image"] = self.file["observation/wrist_0_camera/rgb/image_224_224"][self.step]
+        # data["observation/base_0_camera/rgb/image"] = self.file["observation/base_0_camera/rgb/image_256_320"][self.step]
+        # data["observation/wrist_0_camera/rgb/image"] = self.file["observation/wrist_0_camera/rgb/image_256_320"][self.step]
+        # data["observation/ur5e/joints/position"] = self.file["observation/ur5e/joints/position"][self.step]
+        # data["observation/robotiq_gripper/gripper/position"] = self.file["observation/robotiq_gripper/gripper/position"][self.step]
+
+        return data
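The rendering path in env.py.back hinges on quaternion conventions: Isaac Lab
reports orientations scalar-first (w, x, y, z), while SciPy's
Rotation.from_quat expects scalar-last (x, y, z, w). The imported scalar_last
helper lives in real2simeval.utils; a sketch of what it is assumed to do:

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    def scalar_last(quat_wxyz: np.ndarray) -> np.ndarray:
        # Reorder (w, x, y, z) -> (x, y, z, w) along the last axis.
        return np.roll(quat_wxyz, -1, axis=-1)

    # Identity rotation: wxyz (1, 0, 0, 0) becomes xyzw (0, 0, 0, 1).
    rot = R.from_quat(scalar_last(np.array([1.0, 0.0, 0.0, 0.0])))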
diff --git a/examples/ur_sim/main.py b/examples/ur_sim/main.py
new file mode 100644
index 0000000..97a4d9b
--- /dev/null
+++ b/examples/ur_sim/main.py
@@ -0,0 +1,56 @@
+import dataclasses
+import logging
+import pathlib
+
+import env as _env
+from openpi_client import action_chunk_broker
+from openpi_client import websocket_client_policy as _websocket_client_policy
+from openpi_client.runtime import runtime as _runtime
+from openpi_client.runtime.agents import policy_agent as _policy_agent
+import saver as _saver
+import tyro
+
+
+@dataclasses.dataclass
+class Args:
+    out_path: pathlib.Path = pathlib.Path("replay.mp4")
+
+    task: str = "PIBussing"
+    seed: int = 0
+
+    action_horizon: int = 10
+
+    host: str = "0.0.0.0"
+    port: int = 8000
+
+    display: bool = False
+
+
+def main(args: Args) -> None:
+    runtime = _runtime.Runtime(
+        environment=_env.URSimEnvironment(
+            task=args.task,
+            seed=args.seed,
+        ),
+        agent=_policy_agent.PolicyAgent(
+            policy=action_chunk_broker.ActionChunkBroker(
+                policy=_websocket_client_policy.WebsocketClientPolicy(
+                    host=args.host,
+                    port=args.port,
+                ),
+                action_horizon=args.action_horizon,
+            )
+        ),
+        subscribers=[
+            _saver.VideoSaver(args.out_path),
+        ],
+        max_hz=50,
+    )
+
+    runtime.run()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO, force=True)
+    tyro.cli(main)
diff --git a/examples/ur_sim/saver.py b/examples/ur_sim/saver.py
new file mode 100644
index 0000000..be8fa1a
--- /dev/null
+++ b/examples/ur_sim/saver.py
@@ -0,0 +1,40 @@
+import logging
+import pathlib
+
+import imageio
+import numpy as np
+from openpi_client.runtime import subscriber as _subscriber
+from typing_extensions import override
+
+
+class VideoSaver(_subscriber.Subscriber):
+    """Saves episode data."""
+
+    def __init__(self, out_path: pathlib.Path, subsample: int = 1) -> None:
+        self._out_path = out_path
+        self._images: list[np.ndarray] = []
+        self._subsample = subsample
+
+    @override
+    def on_episode_start(self) -> None:
+        self._images = []
+
+    @override
+    def on_step(self, observation: dict, action: dict) -> None:
+        # Stack the base and wrist views side by side into one frame.
+        img1 = observation["observation/base_0_camera/rgb/image"]
+        img2 = observation["observation/wrist_0_camera/rgb/image"]
+        big_img = np.concatenate([img1, img2], axis=1)
+        self._images.append(big_img)
+
+    @override
+    def on_episode_end(self) -> None:
+        logging.info(f"Saving video to {self._out_path}")
+        imageio.mimwrite(
+            self._out_path,
+            [np.asarray(x) for x in self._images[:: self._subsample]],
+            fps=20 // max(1, self._subsample),
+        )
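VideoSaver relies on the subscriber lifecycle visible above: the runtime calls
on_episode_start once, on_step after every control step, and on_episode_end to
flush the video. A standalone usage sketch with synthetic frames (the file
name and frame sizes here are illustrative only):

    import pathlib

    import numpy as np
    from saver import VideoSaver

    saver = VideoSaver(pathlib.Path("demo.mp4"), subsample=2)
    saver.on_episode_start()
    for _ in range(40):
        img = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)
        saver.on_step(
            {
                "observation/base_0_camera/rgb/image": img,
                "observation/wrist_0_camera/rgb/image": img,
            },
            action={},
        )
    saver.on_episode_end()  # writes demo.mp4 at 20 // 2 = 10 fps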
diff --git a/scripts/serve_policy.py b/scripts/serve_policy.py
index 55d921b..c990a54 100644
--- a/scripts/serve_policy.py
+++ b/scripts/serve_policy.py
@@ -11,7 +11,7 @@ from openpi.models import exported as _exported
 from openpi.models import model as _model
 from openpi.policies import aloha_policy
 from openpi.policies import calvin_policy
-from openpi.policies import droid_policy
+from openpi.policies import droid_policy, ur_policy
 from openpi.policies import libero_policy
 from openpi.policies import policy as _policy
 from openpi.policies import policy_config as _policy_config
@@ -28,6 +28,7 @@ class EnvMode(enum.Enum):
     DROID = "droid"
     CALVIN = "calvin"
     LIBERO = "libero"
+    UR = "ur"
@@ -109,6 +110,10 @@ DEFAULT_EXPORTED: dict[EnvMode, Exported] = {
         dir="s3://openpi-assets/exported/pi0_libero/model",
         processor="libero",
     ),
+    EnvMode.UR: Exported(
+        dir="s3://openpi-assets/exported/pi0_base/model",
+        processor="ur5_single_24dim",
+    ),
 }
@@ -222,9 +227,22 @@ def create_default_policy(
                 libero_policy.LiberoOutputs(),
             ],
         )
+        case EnvMode.UR:
+            delta_action_mask = delta_actions.make_bool_mask(6, -1)
+
+            config = make_policy_config(
+                input_layers=[
+                    ur_policy.URInputs(action_dim=model.action_dim),
+                    transforms.ResizeImages(224, 224),
+                ],
+                output_layers=[
+                    ur_policy.UROutputs(
+                        delta_action_mask=delta_action_mask,
+                    )
+                ],
+            )
         case _:
             raise ValueError(f"Unknown environment mode: {env}")
-
     return _policy_config.create_policy(config)
diff --git a/src/openpi/policies/ur_policy.py b/src/openpi/policies/ur_policy.py
new file mode 100644
index 0000000..4e78ba0
--- /dev/null
+++ b/src/openpi/policies/ur_policy.py
@@ -0,0 +1,58 @@
+from collections.abc import Sequence
+
+import numpy as np
+
+from openpi import transforms
+
+
+class URInputs(transforms.DataTransformFn):
+    def __init__(self, action_dim: int, *, delta_action_mask: Sequence[bool] | None = None):
+        self._action_dim = action_dim
+        self._delta_action_mask = delta_action_mask
+
+    def __call__(self, data: dict) -> dict:
+        state = np.concatenate(
+            [
+                data["observation/ur5e/joints/position"],
+                data["observation/robotiq_gripper/gripper/position"],
+            ],
+            axis=1,
+        )
+        state = transforms.pad_to_dim(state, self._action_dim)
+
+        base_image = data["observation/base_0_camera/rgb/image"]
+
+        inputs = {
+            "state": state,
+            "image": {
+                "base_0_rgb": base_image,
+                "left_wrist_0_rgb": data["observation/wrist_0_camera/rgb/image"],
+                # There is no right wrist camera on this robot; feed a masked-out
+                # dummy image so the model's input signature is satisfied.
+                "right_wrist_0_rgb": np.zeros_like(base_image),
+            },
+            "image_mask": {
+                "base_0_rgb": np.ones(1, dtype=np.bool_),
+                "left_wrist_0_rgb": np.ones(1, dtype=np.bool_),
+                "right_wrist_0_rgb": np.zeros(1, dtype=np.bool_),
+            },
+        }
+
+        if "prompt" in data:
+            inputs["prompt"] = data["prompt"]
+
+        return inputs
+
+
+class UROutputs(transforms.DataTransformFn):
+    def __init__(self, *, delta_action_mask: Sequence[bool] | None = None):
+        self._delta_action_mask = delta_action_mask
+
+    def __call__(self, data: dict) -> dict:
+        # Only return the first 7 dims (6 arm joints + 1 gripper).
+        actions = np.asarray(data["actions"][..., :7])
+
+        # Convert masked delta actions back to absolute actions.
+        if self._delta_action_mask is not None:
+            state = np.asarray(data["state"][..., :7])
+            mask = np.asarray(self._delta_action_mask[:7])
+            actions = actions + np.expand_dims(np.where(mask, state, 0), axis=-2)
+
+        return {"actions": actions}
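How UROutputs recovers absolute joint targets: with
delta_actions.make_bool_mask(6, -1), the first six (arm) dimensions are
treated as deltas relative to the current state, while the seventh (gripper)
dimension stays absolute. A small numeric sketch of the masking math in
UROutputs.__call__, assuming the mask comes out as six Trues followed by one
False:

    import numpy as np

    state = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0])  # joints + gripper
    actions = np.full((4, 7), 0.01)                         # chunk of 4 steps
    mask = np.array([True] * 6 + [False])

    # Broadcast the (1, 7) state offsets across the action-chunk axis.
    absolute = actions + np.expand_dims(np.where(mask, state, 0), axis=-2)
    # Arm dims become state + 0.01 per step; the gripper stays at 0.01.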