From d0cd39f9b5a891885034d3179f66e43222a5f9c4 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 31 Mar 2024 18:26:32 +0200 Subject: [PATCH] Fix tests (when from_pixel=false) --- lerobot/common/envs/simxarm/env.py | 25 ++++++++++++------- .../common/envs/simxarm/simxarm/tasks/base.py | 9 ++++--- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py index dfd684cf..d9926c36 100644 --- a/lerobot/common/envs/simxarm/env.py +++ b/lerobot/common/envs/simxarm/env.py @@ -38,12 +38,14 @@ class SimxarmEnv(AbstractEnv): device="cpu", num_prev_obs=0, num_prev_action=0, - visualization_width=400, - visualization_height=400, + visualization_width=None, + visualization_height=None, ): + self.from_pixels = from_pixels + self.image_size = image_size self.visualization_width = visualization_width self.visualization_height = visualization_height - self.image_size = image_size + super().__init__( task=task, frame_skip=frame_skip, @@ -67,12 +69,17 @@ class SimxarmEnv(AbstractEnv): if self.task not in TASKS: raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}") - kwargs = { - "width": self.image_size, - "height": self.image_size, - "visualization_width": self.visualization_width, - "visualization_height": self.visualization_height, - } + kwargs = ( + { + "width": self.image_size, + "height": self.image_size, + "visualization_width": self.visualization_width, + "visualization_height": self.visualization_height, + } + if self.from_pixels + else {} + ) + self._env = TASKS[self.task]["env"](**kwargs) num_actions = len(TASKS[self.task]["action_space"]) diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/base.py b/lerobot/common/envs/simxarm/simxarm/tasks/base.py index fa7afad6..e0085944 100644 --- a/lerobot/common/envs/simxarm/simxarm/tasks/base.py +++ b/lerobot/common/envs/simxarm/simxarm/tasks/base.py @@ -24,8 +24,10 @@ class Base(robot_env.MujocoRobotEnv): self.center_of_table = np.array([1.655, 0.3, 0.63625]) self.max_z = 1.2 self.min_z = 0.2 - visualization_width = kwargs.pop("visualization_width") - visualization_height = kwargs.pop("visualization_height") + visualization_width = kwargs.pop("visualization_width") if "visualization_width" in kwargs else None + visualization_height = ( + kwargs.pop("visualization_height") if "visualization_height" in kwargs else None + ) super().__init__( model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"), @@ -35,7 +37,8 @@ class Base(robot_env.MujocoRobotEnv): **kwargs, ) - self._set_custom_size_renderer(width=visualization_width, height=visualization_height) + if visualization_width is not None and visualization_height is not None: + self._set_custom_size_renderer(width=visualization_width, height=visualization_height) @property def dt(self):