forked from tangger/lerobot
Fix tests (when from_pixel=false)
This commit is contained in:
@@ -38,12 +38,14 @@ class SimxarmEnv(AbstractEnv):
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
num_prev_obs=0,
|
num_prev_obs=0,
|
||||||
num_prev_action=0,
|
num_prev_action=0,
|
||||||
visualization_width=400,
|
visualization_width=None,
|
||||||
visualization_height=400,
|
visualization_height=None,
|
||||||
):
|
):
|
||||||
|
self.from_pixels = from_pixels
|
||||||
|
self.image_size = image_size
|
||||||
self.visualization_width = visualization_width
|
self.visualization_width = visualization_width
|
||||||
self.visualization_height = visualization_height
|
self.visualization_height = visualization_height
|
||||||
self.image_size = image_size
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task=task,
|
task=task,
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,
|
||||||
@@ -67,12 +69,17 @@ class SimxarmEnv(AbstractEnv):
|
|||||||
if self.task not in TASKS:
|
if self.task not in TASKS:
|
||||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||||
|
|
||||||
kwargs = {
|
kwargs = (
|
||||||
"width": self.image_size,
|
{
|
||||||
"height": self.image_size,
|
"width": self.image_size,
|
||||||
"visualization_width": self.visualization_width,
|
"height": self.image_size,
|
||||||
"visualization_height": self.visualization_height,
|
"visualization_width": self.visualization_width,
|
||||||
}
|
"visualization_height": self.visualization_height,
|
||||||
|
}
|
||||||
|
if self.from_pixels
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
self._env = TASKS[self.task]["env"](**kwargs)
|
self._env = TASKS[self.task]["env"](**kwargs)
|
||||||
|
|
||||||
num_actions = len(TASKS[self.task]["action_space"])
|
num_actions = len(TASKS[self.task]["action_space"])
|
||||||
|
|||||||
@@ -24,8 +24,10 @@ class Base(robot_env.MujocoRobotEnv):
|
|||||||
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
||||||
self.max_z = 1.2
|
self.max_z = 1.2
|
||||||
self.min_z = 0.2
|
self.min_z = 0.2
|
||||||
visualization_width = kwargs.pop("visualization_width")
|
visualization_width = kwargs.pop("visualization_width") if "visualization_width" in kwargs else None
|
||||||
visualization_height = kwargs.pop("visualization_height")
|
visualization_height = (
|
||||||
|
kwargs.pop("visualization_height") if "visualization_height" in kwargs else None
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
||||||
@@ -35,7 +37,8 @@ class Base(robot_env.MujocoRobotEnv):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
|
if visualization_width is not None and visualization_height is not None:
|
||||||
|
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self):
|
def dt(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user