proposal for a more general Dora env
This commit is contained in:
@@ -6,11 +6,11 @@ import pyarrow as pa
|
|||||||
from dora import Node
|
from dora import Node
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
FPS = int(os.getenv("FPS", "30"))
|
||||||
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
|
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
|
||||||
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
|
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
|
||||||
FPS = int(os.getenv("FPS", "30"))
|
|
||||||
|
|
||||||
JOINTS = [
|
ALOHA_JOINTS = [
|
||||||
# absolute joint position
|
# absolute joint position
|
||||||
"left_arm_waist",
|
"left_arm_waist",
|
||||||
"left_arm_shoulder",
|
"left_arm_shoulder",
|
||||||
@@ -30,8 +30,7 @@ JOINTS = [
|
|||||||
# normalized gripper position 0: close, 1: open
|
# normalized gripper position 0: close, 1: open
|
||||||
"right_arm_gripper",
|
"right_arm_gripper",
|
||||||
]
|
]
|
||||||
|
ALOHA_ACTIONS = [
|
||||||
ACTIONS = [
|
|
||||||
# position and quaternion for end effector
|
# position and quaternion for end effector
|
||||||
"left_arm_waist",
|
"left_arm_waist",
|
||||||
"left_arm_shoulder",
|
"left_arm_shoulder",
|
||||||
@@ -55,56 +54,104 @@ ACTIONS = [
|
|||||||
class DoraEnv(gym.Env):
|
class DoraEnv(gym.Env):
|
||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
|
||||||
|
|
||||||
def __init__(self, model="aloha"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model="aloha",
|
||||||
|
observation_width=IMAGE_WIDTH,
|
||||||
|
observation_height=IMAGE_HEIGHT,
|
||||||
|
cameras_names=None,
|
||||||
|
num_joints=None,
|
||||||
|
num_actions=None,
|
||||||
|
):
|
||||||
|
"""Initializes the Dora environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The model to use. Either 'aloha' or 'custom'.
|
||||||
|
observation_width (int): The width of the observation image.
|
||||||
|
observation_height (int): The height of the observation image.
|
||||||
|
cameras_names (list): A list of camera names to use. If not provided, the default is ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'].
|
||||||
|
num_joints (int): The number of joints in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
|
||||||
|
num_actions (int): The number of actions in the model. If not provided, the default is 14 for 'aloha' and 6 for 'fivedof'.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
# Initialize a new node
|
# Initialize a new node
|
||||||
self.node = Node()
|
self.node = Node() if os.environ.get("DORA_NODE_CONFIG", None) is not None else None
|
||||||
self.observation = {"pixels": {}, "agent_pos": None}
|
self.observation = {"pixels": {}, "agent_pos": None}
|
||||||
self.terminated = False
|
self.terminated = False
|
||||||
|
|
||||||
self.observation_height = IMAGE_HEIGHT
|
self.observation_height = observation_height
|
||||||
self.observation_width = IMAGE_WIDTH
|
self.observation_width = observation_width
|
||||||
|
|
||||||
assert model == "aloha"
|
# Observation space
|
||||||
self.observation_space = spaces.Dict(
|
if model == "aloha":
|
||||||
{
|
self.observation_space = spaces.Dict(
|
||||||
"pixels": spaces.Dict(
|
{
|
||||||
{
|
"pixels": spaces.Dict(
|
||||||
"cam_high": spaces.Box(
|
{
|
||||||
low=0,
|
"cam_high": spaces.Box(
|
||||||
high=255,
|
low=0,
|
||||||
shape=(self.observation_height, self.observation_width, 3),
|
high=255,
|
||||||
dtype=np.uint8,
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
),
|
dtype=np.uint8,
|
||||||
"cam_low": spaces.Box(
|
),
|
||||||
low=0,
|
"cam_low": spaces.Box(
|
||||||
high=255,
|
low=0,
|
||||||
shape=(self.observation_height, self.observation_width, 3),
|
high=255,
|
||||||
dtype=np.uint8,
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
),
|
dtype=np.uint8,
|
||||||
"cam_left_wrist": spaces.Box(
|
),
|
||||||
low=0,
|
"cam_left_wrist": spaces.Box(
|
||||||
high=255,
|
low=0,
|
||||||
shape=(self.observation_height, self.observation_width, 3),
|
high=255,
|
||||||
dtype=np.uint8,
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
),
|
dtype=np.uint8,
|
||||||
"cam_right_wrist": spaces.Box(
|
),
|
||||||
low=0,
|
"cam_right_wrist": spaces.Box(
|
||||||
high=255,
|
low=0,
|
||||||
shape=(self.observation_height, self.observation_width, 3),
|
high=255,
|
||||||
dtype=np.uint8,
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
),
|
dtype=np.uint8,
|
||||||
}
|
),
|
||||||
),
|
}
|
||||||
"agent_pos": spaces.Box(
|
),
|
||||||
low=-1000.0,
|
"agent_pos": spaces.Box(
|
||||||
high=1000.0,
|
low=-1000.0,
|
||||||
shape=(len(JOINTS),),
|
high=1000.0,
|
||||||
dtype=np.float64,
|
shape=(len(ALOHA_JOINTS),),
|
||||||
),
|
dtype=np.float64,
|
||||||
}
|
),
|
||||||
)
|
}
|
||||||
|
)
|
||||||
|
elif model == "custom":
|
||||||
|
pixel_dict = {}
|
||||||
|
for camera in cameras_names:
|
||||||
|
assert camera.startswith("cam"), "Camera names must start with 'cam'"
|
||||||
|
pixel_dict[camera] = spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"pixels": spaces.Dict(pixel_dict),
|
||||||
|
"agent_pos": spaces.Box(
|
||||||
|
low=-1000.0,
|
||||||
|
high=1000.0,
|
||||||
|
shape=(num_joints,),
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Model must be either 'aloha' or 'custom'.")
|
||||||
|
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ACTIONS),), dtype=np.float32)
|
# Action space
|
||||||
|
if model == "aloha":
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ALOHA_ACTIONS),), dtype=np.float32)
|
||||||
|
elif model == "custom":
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(num_actions,), dtype=np.float32)
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
while True:
|
while True:
|
||||||
@@ -119,7 +166,7 @@ class DoraEnv(gym.Env):
|
|||||||
# Map Image input into pixels key within Aloha environment
|
# Map Image input into pixels key within Aloha environment
|
||||||
if "cam" in event["id"]:
|
if "cam" in event["id"]:
|
||||||
self.observation["pixels"][event["id"]] = (
|
self.observation["pixels"][event["id"]] = (
|
||||||
event["value"].to_numpy().reshape(IMAGE_HEIGHT, IMAGE_WIDTH, 3)
|
event["value"].to_numpy().reshape(self.observation_height, self.observation_width, 3)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Map other inputs into the observation dictionary using the event id as key
|
# Map other inputs into the observation dictionary using the event id as key
|
||||||
|
|||||||
Reference in New Issue
Block a user