Compare commits
5 Commits
fix/lint_w
...
user/alibe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dbdd7d0c47 | ||
|
|
d0cd39f9b5 | ||
|
|
5bba325fd0 | ||
|
|
bf93fc1875 | ||
|
|
aef0bd8526 |
@@ -20,6 +20,8 @@ def make_env(cfg, transform=None):
|
||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
kwargs["visualization_width"] = cfg.env.visualization_width
|
||||
kwargs["visualization_height"] = cfg.env.visualization_height
|
||||
clsfunc = SimxarmEnv
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
|
||||
@@ -38,7 +38,14 @@ class SimxarmEnv(AbstractEnv):
|
||||
device="cpu",
|
||||
num_prev_obs=0,
|
||||
num_prev_action=0,
|
||||
visualization_width=None,
|
||||
visualization_height=None,
|
||||
):
|
||||
self.from_pixels = from_pixels
|
||||
self.image_size = image_size
|
||||
self.visualization_width = visualization_width
|
||||
self.visualization_height = visualization_height
|
||||
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
@@ -62,7 +69,18 @@ class SimxarmEnv(AbstractEnv):
|
||||
if self.task not in TASKS:
|
||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
kwargs = (
|
||||
{
|
||||
"width": self.image_size,
|
||||
"height": self.image_size,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
}
|
||||
if self.from_pixels
|
||||
else {}
|
||||
)
|
||||
|
||||
self._env = TASKS[self.task]["env"](**kwargs)
|
||||
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
@@ -70,12 +88,12 @@ class SimxarmEnv(AbstractEnv):
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
def render(self, mode="rgb_array"):
|
||||
return self._env.render(mode)
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||
image = self.render(mode="rgb_array")
|
||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# from copy import deepcopy
|
||||
import os
|
||||
|
||||
import mujoco
|
||||
import numpy as np
|
||||
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
|
||||
from gymnasium_robotics.envs import robot_env
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
|
||||
@@ -15,20 +17,29 @@ class Base(robot_env.MujocoRobotEnv):
|
||||
gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
|
||||
"""
|
||||
|
||||
def __init__(self, xml_name, gripper_rotation=None):
|
||||
def __init__(self, xml_name, gripper_rotation=None, **kwargs):
|
||||
if gripper_rotation is None:
|
||||
gripper_rotation = [0, 1, 0, 0]
|
||||
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
|
||||
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
||||
self.max_z = 1.2
|
||||
self.min_z = 0.2
|
||||
visualization_width = kwargs.pop("visualization_width") if "visualization_width" in kwargs else None
|
||||
visualization_height = (
|
||||
kwargs.pop("visualization_height") if "visualization_height" in kwargs else None
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
||||
n_substeps=20,
|
||||
n_actions=4,
|
||||
initial_qpos={},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if visualization_width is not None and visualization_height is not None:
|
||||
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
|
||||
|
||||
@property
|
||||
def dt(self):
|
||||
return self.n_substeps * self.model.opt.timestep
|
||||
@@ -133,12 +144,26 @@ class Base(robot_env.MujocoRobotEnv):
|
||||
info = {"is_success": self.is_success(), "success": self.is_success()}
|
||||
return obs, reward, done, info
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
def render(self, mode="rgb_array"):
|
||||
self._render_callback()
|
||||
|
||||
if mode == "visualization":
|
||||
return self._custom_size_render()
|
||||
|
||||
return self.mujoco_renderer.render(mode, camera_name="camera0")
|
||||
|
||||
def _set_custom_size_renderer(self, width, height):
|
||||
from copy import deepcopy
|
||||
|
||||
# HACK
|
||||
self.model.vis.global_.offwidth = width
|
||||
self.model.vis.global_.offheight = height
|
||||
return self.mujoco_renderer.render(mode)
|
||||
custom_render_model = deepcopy(self.model)
|
||||
custom_render_model.vis.global_.offwidth = width
|
||||
custom_render_model.vis.global_.offheight = height
|
||||
self.custom_size_renderer = MujocoRenderer(custom_render_model, self.data)
|
||||
del custom_render_model
|
||||
|
||||
def _custom_size_render(self):
|
||||
return self.custom_size_renderer.render("rgb_array", camera_name="camera0")
|
||||
|
||||
def close(self):
|
||||
if self.mujoco_renderer is not None:
|
||||
|
||||
@@ -4,9 +4,9 @@ from lerobot.common.envs.simxarm.simxarm import Base
|
||||
|
||||
|
||||
class Lift(Base):
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
self._z_threshold = 0.15
|
||||
super().__init__("lift")
|
||||
super().__init__("lift", **kwargs)
|
||||
|
||||
@property
|
||||
def z_target(self):
|
||||
|
||||
2
lerobot/configs/env/simxarm.yaml
vendored
2
lerobot/configs/env/simxarm.yaml
vendored
@@ -20,6 +20,8 @@ env:
|
||||
action_repeat: 2
|
||||
episode_length: 25
|
||||
fps: ${fps}
|
||||
visualization_width: 400
|
||||
visualization_height: 400
|
||||
|
||||
policy:
|
||||
state_dim: 4
|
||||
|
||||
@@ -49,6 +49,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
@@ -86,7 +87,10 @@ def eval_policy(
|
||||
|
||||
def maybe_render_frame(env: EnvBase, _):
|
||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
# HACK
|
||||
# TODO(aliberts): set render_mode for all envs
|
||||
render_mode = "visualization" if isinstance(env, SimxarmEnv) else "rgb_array"
|
||||
ep_frames.append(env.render(mode=render_mode)) # noqa: B023
|
||||
|
||||
# Clear the policy's action queue before the start of a new rollout.
|
||||
if policy is not None:
|
||||
|
||||
Reference in New Issue
Block a user