From bf93fc18757e3efc120fa393237d4f0ad48b4370 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 31 Mar 2024 14:55:13 +0200 Subject: [PATCH] Remove image size hack --- lerobot/common/envs/simxarm/env.py | 4 +++- lerobot/common/envs/simxarm/simxarm/tasks/base.py | 6 ++---- lerobot/common/envs/simxarm/simxarm/tasks/lift.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py index b81bf4992..b8f19057d 100644 --- a/lerobot/common/envs/simxarm/env.py +++ b/lerobot/common/envs/simxarm/env.py @@ -39,6 +39,7 @@ class SimxarmEnv(AbstractEnv): num_prev_obs=0, num_prev_action=0, ): + self.image_size = image_size super().__init__( task=task, frame_skip=frame_skip, @@ -62,7 +63,8 @@ class SimxarmEnv(AbstractEnv): if self.task not in TASKS: raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}") - self._env = TASKS[self.task]["env"]() + kwargs = {"width": self.image_size, "height": self.image_size} + self._env = TASKS[self.task]["env"](**kwargs) num_actions = len(TASKS[self.task]["action_space"]) self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/base.py b/lerobot/common/envs/simxarm/simxarm/tasks/base.py index 5cb157efb..167dafe8f 100644 --- a/lerobot/common/envs/simxarm/simxarm/tasks/base.py +++ b/lerobot/common/envs/simxarm/simxarm/tasks/base.py @@ -15,7 +15,7 @@ class Base(robot_env.MujocoRobotEnv): gripper_rotation (list): initial rotation of the gripper (given as a quaternion) """ - def __init__(self, xml_name, gripper_rotation=None): + def __init__(self, xml_name, gripper_rotation=None, **kwargs): if gripper_rotation is None: gripper_rotation = [0, 1, 0, 0] self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32) @@ -27,6 +27,7 @@ class Base(robot_env.MujocoRobotEnv): n_substeps=20, n_actions=4, initial_qpos={}, + **kwargs, ) @property @@ -135,9 +136,6 @@ class Base(robot_env.MujocoRobotEnv): def render(self, mode="rgb_array", width=384, height=384): self._render_callback() - # HACK - self.model.vis.global_.offwidth = width - self.model.vis.global_.offheight = height return self.mujoco_renderer.render(mode, camera_name="camera0") def close(self): diff --git a/lerobot/common/envs/simxarm/simxarm/tasks/lift.py b/lerobot/common/envs/simxarm/simxarm/tasks/lift.py index 0b11196c7..c59d062c2 100644 --- a/lerobot/common/envs/simxarm/simxarm/tasks/lift.py +++ b/lerobot/common/envs/simxarm/simxarm/tasks/lift.py @@ -4,9 +4,9 @@ from lerobot.common.envs.simxarm.simxarm import Base class Lift(Base): - def __init__(self): + def __init__(self, **kwargs): self._z_threshold = 0.15 - super().__init__("lift") + super().__init__("lift", **kwargs) @property def z_target(self):