forked from tangger/lerobot
Fix tests (when from_pixel=false)
This commit is contained in:
@@ -38,12 +38,14 @@ class SimxarmEnv(AbstractEnv):
|
||||
device="cpu",
|
||||
num_prev_obs=0,
|
||||
num_prev_action=0,
|
||||
visualization_width=400,
|
||||
visualization_height=400,
|
||||
visualization_width=None,
|
||||
visualization_height=None,
|
||||
):
|
||||
self.from_pixels = from_pixels
|
||||
self.image_size = image_size
|
||||
self.visualization_width = visualization_width
|
||||
self.visualization_height = visualization_height
|
||||
self.image_size = image_size
|
||||
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
@@ -67,12 +69,17 @@ class SimxarmEnv(AbstractEnv):
|
||||
if self.task not in TASKS:
|
||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||
|
||||
kwargs = {
|
||||
"width": self.image_size,
|
||||
"height": self.image_size,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
}
|
||||
kwargs = (
|
||||
{
|
||||
"width": self.image_size,
|
||||
"height": self.image_size,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
}
|
||||
if self.from_pixels
|
||||
else {}
|
||||
)
|
||||
|
||||
self._env = TASKS[self.task]["env"](**kwargs)
|
||||
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
|
||||
@@ -24,8 +24,10 @@ class Base(robot_env.MujocoRobotEnv):
|
||||
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
||||
self.max_z = 1.2
|
||||
self.min_z = 0.2
|
||||
visualization_width = kwargs.pop("visualization_width")
|
||||
visualization_height = kwargs.pop("visualization_height")
|
||||
visualization_width = kwargs.pop("visualization_width") if "visualization_width" in kwargs else None
|
||||
visualization_height = (
|
||||
kwargs.pop("visualization_height") if "visualization_height" in kwargs else None
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
||||
@@ -35,7 +37,8 @@ class Base(robot_env.MujocoRobotEnv):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
|
||||
if visualization_width is not None and visualization_height is not None:
|
||||
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
|
||||
|
||||
@property
|
||||
def dt(self):
|
||||
|
||||
Reference in New Issue
Block a user