Add some documentation
This commit is contained in:
@@ -12,6 +12,7 @@ class DoraEnv(gym.Env):
|
|||||||
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
|
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
|
||||||
|
|
||||||
def __init__(self, model="aloha"):
|
def __init__(self, model="aloha"):
|
||||||
|
# Initialize a new node
|
||||||
self.node = Node()
|
self.node = Node()
|
||||||
self.observation = {"pixels": {}, "terminated": False}
|
self.observation = {"pixels": {}, "terminated": False}
|
||||||
|
|
||||||
@@ -19,14 +20,20 @@ class DoraEnv(gym.Env):
|
|||||||
while True:
|
while True:
|
||||||
event = self.node.next(timeout=0.001)
|
event = self.node.next(timeout=0.001)
|
||||||
|
|
||||||
|
## If event is None, the node event stream is closed and we should terminate the env
|
||||||
if event is None:
|
if event is None:
|
||||||
self.observation["terminated"] = True
|
self.observation["terminated"] = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if event["type"] == "INPUT":
|
if event["type"] == "INPUT":
|
||||||
|
# Map Image input into pixels key within Aloha environment
|
||||||
if "cam" in event["id"]:
|
if "cam" in event["id"]:
|
||||||
self.observation["pixels"][event["id"]] = event["value"].to_numpy().reshape(IMAGE_HEIGHT, IMAGE_WIDTH, 3)
|
self.observation["pixels"][event["id"]] = event["value"].to_numpy().reshape(IMAGE_HEIGHT, IMAGE_WIDTH, 3)
|
||||||
else:
|
else:
|
||||||
|
# Map other inputs into the observation dictionary using the event id as key
|
||||||
self.observation[event["id"]] = event["value"].to_numpy()
|
self.observation[event["id"]] = event["value"].to_numpy()
|
||||||
|
|
||||||
|
# If the event is a timeout error break the update loop.
|
||||||
elif event["type"] == "ERROR":
|
elif event["type"] == "ERROR":
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -44,6 +51,8 @@ class DoraEnv(gym.Env):
|
|||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
self._update()
|
self._update()
|
||||||
|
|
||||||
|
# Send the action to the dataflow as action key.
|
||||||
self.node.send_output("action", pa.array(action))
|
self.node.send_output("action", pa.array(action))
|
||||||
reward = 0
|
reward = 0
|
||||||
terminated = truncated = self.observation["terminated"]
|
terminated = truncated = self.observation["terminated"]
|
||||||
@@ -51,5 +60,7 @@ class DoraEnv(gym.Env):
|
|||||||
return self.observation, reward, terminated, truncated, info
|
return self.observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|
||||||
|
# Drop the node
|
||||||
del self.node
|
del self.node
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user