proposal for a more general Dora env

This commit is contained in:
Thomas Wolf
2024-05-29 15:29:41 +02:00
parent 68a680a9eb
commit e2f690e779

View File

@@ -6,11 +6,11 @@ import pyarrow as pa
from dora import Node from dora import Node
from gymnasium import spaces from gymnasium import spaces
FPS = int(os.getenv("FPS", "30"))
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640")) IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480")) IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
FPS = int(os.getenv("FPS", "30"))
JOINTS = [ ALOHA_JOINTS = [
# absolute joint position # absolute joint position
"left_arm_waist", "left_arm_waist",
"left_arm_shoulder", "left_arm_shoulder",
@@ -30,8 +30,7 @@ JOINTS = [
# normalized gripper position 0: close, 1: open # normalized gripper position 0: close, 1: open
"right_arm_gripper", "right_arm_gripper",
] ]
ALOHA_ACTIONS = [
ACTIONS = [
# position and quaternion for end effector # position and quaternion for end effector
"left_arm_waist", "left_arm_waist",
"left_arm_shoulder", "left_arm_shoulder",
@@ -55,56 +54,104 @@ ACTIONS = [
class DoraEnv(gym.Env): class DoraEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS} metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
def __init__(self, model="aloha"): def __init__(
self,
model="aloha",
observation_width=IMAGE_WIDTH,
observation_height=IMAGE_HEIGHT,
cameras_names=None,
num_joints=None,
num_actions=None,
):
"""Initializes the Dora environment.
Args:
model (str): The model to use. Either 'aloha' or 'custom'.
observation_width (int): The width of the observation image.
observation_height (int): The height of the observation image.
cameras_names (list): A list of camera names to use. If not provided, the default is ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'].
num_joints (int): The number of joints in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
num_actions (int): The number of actions in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
"""
super().__init__()
# Initialize a new node # Initialize a new node
self.node = Node() self.node = Node() if os.environ.get("DORA_NODE_CONFIG", None) is not None else None
self.observation = {"pixels": {}, "agent_pos": None} self.observation = {"pixels": {}, "agent_pos": None}
self.terminated = False self.terminated = False
self.observation_height = IMAGE_HEIGHT self.observation_height = observation_height
self.observation_width = IMAGE_WIDTH self.observation_width = observation_width
assert model == "aloha" # Observation space
self.observation_space = spaces.Dict( if model == "aloha":
{ self.observation_space = spaces.Dict(
"pixels": spaces.Dict( {
{ "pixels": spaces.Dict(
"cam_high": spaces.Box( {
low=0, "cam_high": spaces.Box(
high=255, low=0,
shape=(self.observation_height, self.observation_width, 3), high=255,
dtype=np.uint8, shape=(self.observation_height, self.observation_width, 3),
), dtype=np.uint8,
"cam_low": spaces.Box( ),
low=0, "cam_low": spaces.Box(
high=255, low=0,
shape=(self.observation_height, self.observation_width, 3), high=255,
dtype=np.uint8, shape=(self.observation_height, self.observation_width, 3),
), dtype=np.uint8,
"cam_left_wrist": spaces.Box( ),
low=0, "cam_left_wrist": spaces.Box(
high=255, low=0,
shape=(self.observation_height, self.observation_width, 3), high=255,
dtype=np.uint8, shape=(self.observation_height, self.observation_width, 3),
), dtype=np.uint8,
"cam_right_wrist": spaces.Box( ),
low=0, "cam_right_wrist": spaces.Box(
high=255, low=0,
shape=(self.observation_height, self.observation_width, 3), high=255,
dtype=np.uint8, shape=(self.observation_height, self.observation_width, 3),
), dtype=np.uint8,
} ),
), }
"agent_pos": spaces.Box( ),
low=-1000.0, "agent_pos": spaces.Box(
high=1000.0, low=-1000.0,
shape=(len(JOINTS),), high=1000.0,
dtype=np.float64, shape=(len(ALOHA_JOINTS),),
), dtype=np.float64,
} ),
) }
)
elif model == "custom":
pixel_dict = {}
for camera in cameras_names:
assert camera.startswith("cam"), "Camera names must start with 'cam'"
pixel_dict[camera] = spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(pixel_dict),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
),
}
)
else:
raise ValueError("Model must be either 'aloha' or 'custom'.")
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ACTIONS),), dtype=np.float32) # Action space
if model == "aloha":
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ALOHA_ACTIONS),), dtype=np.float32)
elif model == "custom":
self.action_space = spaces.Box(low=-1, high=1, shape=(num_actions,), dtype=np.float32)
def _get_obs(self): def _get_obs(self):
while True: while True:
@@ -119,7 +166,7 @@ class DoraEnv(gym.Env):
# Map Image input into pixels key within Aloha environment # Map Image input into pixels key within Aloha environment
if "cam" in event["id"]: if "cam" in event["id"]:
self.observation["pixels"][event["id"]] = ( self.observation["pixels"][event["id"]] = (
event["value"].to_numpy().reshape(IMAGE_HEIGHT, IMAGE_WIDTH, 3) event["value"].to_numpy().reshape(self.observation_height, self.observation_width, 3)
) )
else: else:
# Map other inputs into the observation dictionary using the event id as key # Map other inputs into the observation dictionary using the event id as key