From 1e6b7d249bcea36a8eea9fce7ae2ce28aede0863 Mon Sep 17 00:00:00 2001 From: haixuantao Date: Tue, 21 May 2024 14:16:29 +0200 Subject: [PATCH] Add some documentation --- gym_dora/gym_dora/env.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/gym_dora/gym_dora/env.py b/gym_dora/gym_dora/env.py index d0b406e3b..664a76d58 100644 --- a/gym_dora/gym_dora/env.py +++ b/gym_dora/gym_dora/env.py @@ -12,6 +12,7 @@ class DoraEnv(gym.Env): metadata = {"render_modes": ["rgb_array"], "render_fps": FPS} def __init__(self, model="aloha"): + # Initialize a new node self.node = Node() self.observation = {"pixels": {}, "terminated": False} @@ -19,14 +20,20 @@ class DoraEnv(gym.Env): while True: event = self.node.next(timeout=0.001) + ## If event is None, the node event stream is closed and we should terminate the env if event is None: self.observation["terminated"] = True break + if event["type"] == "INPUT": + # Map Image input into pixels key within Aloha environment if "cam" in event["id"]: self.observation["pixels"][event["id"]] = event["value"].to_numpy().reshape(IMAGE_HEIGHT, IMAGE_WIDTH, 3) else: + # Map other inputs into the observation dictionary using the event id as key self.observation[event["id"]] = event["value"].to_numpy() + + # If the event is a timeout error break the update loop. elif event["type"] == "ERROR": break @@ -44,6 +51,8 @@ class DoraEnv(gym.Env): def step(self, action: np.ndarray): self._update() + + # Send the action to the dataflow as action key. self.node.send_output("action", pa.array(action)) reward = 0 terminated = truncated = self.observation["terminated"] @@ -51,5 +60,7 @@ class DoraEnv(gym.Env): return self.observation, reward, terminated, truncated, info def close(self): + + # Drop the node del self.node