# Files
# openpi/examples/ur_sim/env.py.back

# 229 lines
# 8.0 KiB
# Plaintext

import argparse
import time  # NOTE(review): not used anywhere in the code visible in this file.
import sys
import logging
# Raise gymnasium's logger threshold to ERROR. Presumably this hides the warnings
# gymnasium emits when env attributes (e.g. `.scene`) are read through the
# wrappers added by gymnasium.make — confirm.
logging.getLogger('gymnasium').setLevel(logging.ERROR)
import warnings
# Process-wide: every UserWarning from any library is suppressed from here on.
warnings.filterwarnings('ignore', category=UserWarning)
from omni.isaac.lab.app import AppLauncher
# Command-line handling and Omniverse app launch. This runs at import time and is
# placed before the remaining omni.isaac.* imports further down, which Isaac Lab
# expects to be imported only after the simulation app has been launched.
# add argparse arguments
# NOTE(review): this description is a leftover from an Isaac Lab tutorial and does
# not describe this script.
parser = argparse.ArgumentParser(description="Tutorial on using the differential IK controller.")
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments; anything AppLauncher does not recognise lands in other_args
# (args_cli is read again later, e.g. args_cli.device when building the env config)
args_cli, other_args = parser.parse_known_args()
sys.argv = [sys.argv[0]] + other_args # keep only unconsumed args so hydra does not see AppLauncher flags
# launch omniverse app
# Cameras are always enabled and the app always runs headless, overriding
# whatever the corresponding CLI flags said.
args_cli.enable_cameras = True
args_cli.headless = True
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
"""Rest everything follows."""
import cv2
import h5py
import torch
import gymnasium
import numpy as np
from pathlib import Path
from openpi_client.runtime import environment as _environment
from typing_extensions import override
from scipy.spatial.transform import Rotation as R
import real2simeval.environments
from real2simeval.splat_render.render import SplatRenderer
from real2simeval.utils import get_transform_from_txt, scalar_last, decrease_brightness
from omni.isaac.lab_tasks.utils import parse_env_cfg
from omni.isaac.core.prims import GeometryPrimView
import omni.isaac.lab.utils.math as math
class URSimEnvironment(_environment.Environment):
    """openpi runtime environment for a simulated UR5e arm with a Robotiq gripper.

    Physics runs in an Isaac Lab gymnasium task (one env). Camera images are not
    taken from Isaac Lab's renderer; they are produced by a ``SplatRenderer`` whose
    per-object poses are copied from the simulated scene on every observation.

    Observations are dicts with the keys
    ``observation/ur5e/joints/position``,
    ``observation/robotiq_gripper/gripper/position``,
    ``observation/base_0_camera/rgb/image`` and
    ``observation/wrist_0_camera/rgb/image``.
    """

    def __init__(self, task: str, seed: int = 0, show_viewer: bool = True) -> None:
        """Create the Isaac Lab task and the splat renderer.

        Args:
            task: Gymnasium id of the Isaac Lab task to instantiate.
            seed: Seeds numpy's global RNG and the generator that draws the
                per-episode reset seeds.
            show_viewer: If True (default, the original behaviour), display both
                camera images in an OpenCV window after every step. Set to False
                when no display is available.
        """
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)
        self._show_viewer = show_viewer

        # Recorded episode; only read by the commented-out replay/debug code below.
        # Released by close().
        self.file = h5py.File("data/episode.h5", "r")
        # Number of apply_action() calls so far; indexes frames of `self.file`
        # in the replay/debug code.
        self.step = 0

        env_cfg = parse_env_cfg(
            task,
            device=args_cli.device,
            num_envs=1,
            use_fabric=True,
        )
        self._gym = gymnasium.make(task, cfg=env_cfg)

        # Splat assets keyed by object name; robot links are added below.
        splats = {
            "pi_scene_v2": "./data/pi_scene_v2/splat.ply",
            "bottle": "./data/pi_objects/bottle/splat.ply",
            "plate": "./data/pi_objects/plate/splat.ply",
        }
        # Robot link name -> prim view used to read that link's world pose.
        views = {}
        robot = Path("./data/pi_robot/SEGMENTED/")
        for ply in robot.glob("*.ply"):
            splats[ply.stem] = str(ply)
            # File stems encode the link's prim path with "-" in place of "/".
            path = ply.stem.replace("-", "/")
            views[ply.stem] = GeometryPrimView(
                prim_paths_expr=f"/World/envs/env_.*/robot/{path}",
            )

        splat_renderer = SplatRenderer(splats=splats)
        splat_renderer.init_cameras({
            "hand_cam": {"fovy": 1.04, "fovx": 1.33, "res": (480, 640)},
            "third_person_cam": {"fovy": 1.04, "fovx": 1.33, "res": (480, 640)},
            # Alternative intrinsics kept for reference:
            # "hand_cam": { "fovy": 0.7925, "fovx": 1.01, "res": (480, 640) },
            # "third_person_cam": { "fovy": 0.7925, "fovx": 1.01, "res": (480, 640) },
        })
        self.splats = splats
        self.views = views
        self.splat_renderer = splat_renderer

        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0
        # Whether the "pi_scene_v2" splat has been posed in the current episode.
        # Initialised here so an observation taken before reset() cannot hit an
        # AttributeError; reset() clears it again.
        self.env_transformed = False

    @override
    def reset(self) -> None:
        """Reset the simulation with a freshly drawn seed and cache the first observation."""
        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
        # Make the scene splat be posed again from the freshly reset scene.
        self.env_transformed = False
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def is_episode_complete(self) -> bool:
        """Return True once the task has reported termination or truncation.

        This is the completion hook named by the openpi_client ``Environment``
        interface.
        """
        return self._done

    def done(self) -> bool:
        """Backward-compatible alias for :meth:`is_episode_complete`."""
        return self.is_episode_complete()

    @override
    def get_observation(self) -> dict:
        """Return the observation cached by the last reset() or apply_action().

        Raises:
            RuntimeError: if reset() has not been called yet.
        """
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")
        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        """Step the simulation with ``action["actions"]`` and cache the new observation.

        The last element of the action is the gripper command. It is rescaled from
        [0, 1] to [-1, 1] before being sent to the simulator.
        """
        # Replay recorded joints instead of the policy output (debug):
        # ur5e = self.file["observation/ur5e/joints/position"][self.step]
        # robotiq = self.file["observation/robotiq_gripper/gripper/position"][self.step]
        # command = np.concatenate([ur5e, robotiq], axis=-1)

        # Copy so the caller's array is not mutated by the gripper rescale.
        command = action["actions"].copy()
        command[-1] = command[-1] * 2 - 1
        # Add the leading env axis (the task is built with num_envs=1).
        command = torch.tensor(command, dtype=torch.float32)[None]
        gym_obs, reward, terminated, truncated, info = self._gym.step(command)
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        # terminated/truncated may be single-element tensors (one env) rather than
        # Python bools. Store a real bool so is_episode_complete() honours `-> bool`.
        self._done = bool(terminated) or bool(truncated)
        # self._episode_reward = max(self._episode_reward, reward)

        if self._show_viewer:
            base_img = self._last_obs["observation/base_0_camera/rgb/image"]
            wrist_img = self._last_obs["observation/wrist_0_camera/rgb/image"]
            big_img = np.concatenate([base_img, wrist_img], axis=1)
            cv2.imshow("big_img", cv2.cvtColor(big_img, cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)
        self.step += 1

    def close(self) -> None:
        """Release the recorded-episode file, the simulation env and the viewer window."""
        self.file.close()
        self._gym.close()
        if self._show_viewer:
            cv2.destroyAllWindows()

    def _sync_splat_poses(self) -> None:
        """Copy world poses from the simulation onto the splats in the renderer.

        Robot links are read through their prim views. Other splats use the scene
        entity of the same name and are skipped if the scene has none. The
        "pi_scene_v2" splat is only posed on the first observation of an episode.
        """
        # .unwrapped avoids gymnasium's deprecated attribute forwarding through wrappers.
        scene = self._gym.unwrapped.scene
        for splat in self.splats:
            if splat == "pi_scene_v2":
                if self.env_transformed:
                    continue
                self.env_transformed = True
            if splat in self.views:
                pos, rot = self.views[splat].get_world_poses()
                pos, rot = pos.squeeze(), rot.squeeze()
            else:
                try:
                    body = scene[splat]
                except KeyError:
                    # No scene entity with this name: leave its splat untouched.
                    continue
                # Env 0's root state: position in [:3], quaternion in [3:7].
                pos = body.data.root_state_w[0, :3]
                rot = body.data.root_state_w[0, 3:7]
            self.splat_renderer.transform(
                pos,
                math.matrix_from_quat(rot),
                scale_factor=1.0,
                obj=splat,
            )

    def _camera_pose(self, sensor_name: str) -> tuple[np.ndarray, np.ndarray]:
        """Return (world position, 3x3 world rotation matrix) of env 0's camera sensor."""
        cam_data = self._gym.unwrapped.scene[sensor_name].data
        pos = cam_data.pos_w[0].detach().cpu().numpy()
        quat = cam_data.quat_w_world[0].detach().cpu().numpy()
        # scipy's from_quat expects scalar-last (x, y, z, w) order.
        rot = R.from_quat(scalar_last(quat)).as_matrix()
        return pos, rot

    def _convert_observation(self, gym_obs: dict) -> dict:
        """Sync splat poses, render both cameras and build the policy observation dict."""
        self._sync_splat_poses()

        hand_pos, hand_rot = self._camera_pose("handcam")
        base_pos, base_rot = self._camera_pose("camera")
        cam_extrinsics_dict = {
            "hand_cam": {"pos": hand_pos, "rot": hand_rot},
            "third_person_cam": {"pos": base_pos, "rot": base_rot},
        }
        # Renderer output is scaled by 255 into uint8. The clip prevents wrap-around
        # (e.g. 1.01 -> 1) if a value falls slightly outside [0, 1].
        rgb = {
            name: (np.clip(image.detach().cpu().numpy(), 0.0, 1.0) * 255).astype(np.uint8)
            for name, image in self.splat_renderer.render(cam_extrinsics_dict).items()
        }

        joints = gym_obs["policy"]["joints"]
        # Isaac Lab observation terms are normally batched as (num_envs, dim).
        # Select env 0 so the [:6] / [6:] split below runs over joints, not envs.
        # An already 1-D tensor is used unchanged.
        # NOTE(review): confirm the "joints" term layout in the task config.
        if joints.ndim > 1:
            joints = joints[0]
        joints = joints.detach().cpu().numpy()

        # To feed recorded data instead (debug), read e.g.
        # self.file["observation/base_0_camera/rgb/image_224_224"][self.step].
        return {
            "observation/ur5e/joints/position": joints[:6],
            "observation/robotiq_gripper/gripper/position": joints[6:],
            "observation/base_0_camera/rgb/image": rgb["third_person_cam"],
            "observation/wrist_0_camera/rgb/image": rgb["hand_cam"],
        }