diff --git a/gym_dora/gym_dora/env.py b/gym_dora/gym_dora/env.py index be94faa3..5316fc72 100644 --- a/gym_dora/gym_dora/env.py +++ b/gym_dora/gym_dora/env.py @@ -4,11 +4,53 @@ import gymnasium as gym import numpy as np import pyarrow as pa from dora import Node +from gymnasium import spaces IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640")) IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480")) FPS = int(os.getenv("FPS", "30")) +JOINTS = [ + # absolute joint position + "left_arm_waist", + "left_arm_shoulder", + "left_arm_elbow", + "left_arm_forearm_roll", + "left_arm_wrist_angle", + "left_arm_wrist_rotate", + # normalized gripper position 0: close, 1: open + "left_arm_gripper", + # absolute joint position + "right_arm_waist", + "right_arm_shoulder", + "right_arm_elbow", + "right_arm_forearm_roll", + "right_arm_wrist_angle", + "right_arm_wrist_rotate", + # normalized gripper position 0: close, 1: open + "right_arm_gripper", +] + +ACTIONS = [ + # position and quaternion for end effector + "left_arm_waist", + "left_arm_shoulder", + "left_arm_elbow", + "left_arm_forearm_roll", + "left_arm_wrist_angle", + "left_arm_wrist_rotate", + # normalized gripper position (0: close, 1: open) + "left_arm_gripper", + "right_arm_waist", + "right_arm_shoulder", + "right_arm_elbow", + "right_arm_forearm_roll", + "right_arm_wrist_angle", + "right_arm_wrist_rotate", + # normalized gripper position (0: close, 1: open) + "right_arm_gripper", +] + class DoraEnv(gym.Env): metadata = {"render_modes": ["rgb_array"], "render_fps": FPS} @@ -19,6 +61,51 @@ class DoraEnv(gym.Env): self.observation = {"pixels": {}, "agent_pos": None} self.terminated = False + self.observation_height = IMAGE_HEIGHT + self.observation_width = IMAGE_WIDTH + + assert model == "aloha" + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict( + { + "cam_high": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + "cam_low": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + "cam_left_wrist": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + "cam_right_wrist": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + } + ), + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(len(JOINTS),), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-1, high=1, shape=(len(ACTIONS),), dtype=np.float32) + def _get_obs(self): while True: event = self.node.next(timeout=0.001)