Fix tests (when from_pixel=false)

This commit is contained in:
Simon Alibert
2024-03-31 18:26:32 +02:00
parent 5bba325fd0
commit d0cd39f9b5
2 changed files with 22 additions and 12 deletions

View File

@@ -38,12 +38,14 @@ class SimxarmEnv(AbstractEnv):
device="cpu",
num_prev_obs=0,
num_prev_action=0,
visualization_width=400,
visualization_height=400,
visualization_width=None,
visualization_height=None,
):
self.from_pixels = from_pixels
self.image_size = image_size
self.visualization_width = visualization_width
self.visualization_height = visualization_height
self.image_size = image_size
super().__init__(
task=task,
frame_skip=frame_skip,
@@ -67,12 +69,17 @@ class SimxarmEnv(AbstractEnv):
if self.task not in TASKS:
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
kwargs = {
"width": self.image_size,
"height": self.image_size,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
}
kwargs = (
{
"width": self.image_size,
"height": self.image_size,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
}
if self.from_pixels
else {}
)
self._env = TASKS[self.task]["env"](**kwargs)
num_actions = len(TASKS[self.task]["action_space"])

View File

@@ -24,8 +24,10 @@ class Base(robot_env.MujocoRobotEnv):
self.center_of_table = np.array([1.655, 0.3, 0.63625])
self.max_z = 1.2
self.min_z = 0.2
visualization_width = kwargs.pop("visualization_width")
visualization_height = kwargs.pop("visualization_height")
visualization_width = kwargs.pop("visualization_width") if "visualization_width" in kwargs else None
visualization_height = (
kwargs.pop("visualization_height") if "visualization_height" in kwargs else None
)
super().__init__(
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
@@ -35,7 +37,8 @@ class Base(robot_env.MujocoRobotEnv):
**kwargs,
)
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
if visualization_width is not None and visualization_height is not None:
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
@property
def dt(self):