Remove image size hack
This commit is contained in:
@@ -39,6 +39,7 @@ class SimxarmEnv(AbstractEnv):
|
|||||||
num_prev_obs=0,
|
num_prev_obs=0,
|
||||||
num_prev_action=0,
|
num_prev_action=0,
|
||||||
):
|
):
|
||||||
|
self.image_size = image_size
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task=task,
|
task=task,
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,
|
||||||
@@ -62,7 +63,8 @@ class SimxarmEnv(AbstractEnv):
|
|||||||
if self.task not in TASKS:
|
if self.task not in TASKS:
|
||||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||||
|
|
||||||
self._env = TASKS[self.task]["env"]()
|
kwargs = {"width": self.image_size, "height": self.image_size}
|
||||||
|
self._env = TASKS[self.task]["env"](**kwargs)
|
||||||
|
|
||||||
num_actions = len(TASKS[self.task]["action_space"])
|
num_actions = len(TASKS[self.task]["action_space"])
|
||||||
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class Base(robot_env.MujocoRobotEnv):
|
|||||||
gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
|
gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, xml_name, gripper_rotation=None):
|
def __init__(self, xml_name, gripper_rotation=None, **kwargs):
|
||||||
if gripper_rotation is None:
|
if gripper_rotation is None:
|
||||||
gripper_rotation = [0, 1, 0, 0]
|
gripper_rotation = [0, 1, 0, 0]
|
||||||
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
|
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
|
||||||
@@ -27,6 +27,7 @@ class Base(robot_env.MujocoRobotEnv):
|
|||||||
n_substeps=20,
|
n_substeps=20,
|
||||||
n_actions=4,
|
n_actions=4,
|
||||||
initial_qpos={},
|
initial_qpos={},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -135,9 +136,6 @@ class Base(robot_env.MujocoRobotEnv):
|
|||||||
|
|
||||||
def render(self, mode="rgb_array", width=384, height=384):
|
def render(self, mode="rgb_array", width=384, height=384):
|
||||||
self._render_callback()
|
self._render_callback()
|
||||||
# HACK
|
|
||||||
self.model.vis.global_.offwidth = width
|
|
||||||
self.model.vis.global_.offheight = height
|
|
||||||
return self.mujoco_renderer.render(mode, camera_name="camera0")
|
return self.mujoco_renderer.render(mode, camera_name="camera0")
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ from lerobot.common.envs.simxarm.simxarm import Base
|
|||||||
|
|
||||||
|
|
||||||
class Lift(Base):
|
class Lift(Base):
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
self._z_threshold = 0.15
|
self._z_threshold = 0.15
|
||||||
super().__init__("lift")
|
super().__init__("lift", **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def z_target(self):
|
def z_target(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user