Stable version of rlpd + drq

This commit is contained in:
AdilZouitine
2025-01-22 09:00:16 +00:00
committed by Michel Aractingi
parent ef777993cd
commit c8b1132846
6 changed files with 425 additions and 174 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque
import gymnasium as gym
@@ -67,3 +68,86 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
)
return env
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
num_envs=n_envs,
)
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
env = ManiSkillVectorEnv(env, ignore_terminations=True)
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
return env
class PixelWrapper(gym.Wrapper):
"""
Wrapper for pixel observations. Works with Maniskill vectorized environments
"""
def __init__(self, cfg, env, num_envs, num_frames=3):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
dtype=np.uint8,
)
self._frames = deque([], maxlen=num_frames)
self._render_size = cfg.env.render_size
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
for _ in range(self._frames.maxlen):
obs_frames = self._get_obs(obs)
return obs_frames, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs):
super().__init__(env)
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options={})
return self._get_obs(obs), info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
def _get_obs(self, observation):
sensor_data = observation.pop("sensor_data")
del observation["sensor_param"]
images = []
for cam_data in sensor_data.values():
images.append(cam_data["rgb"])
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
ret = dict()
ret["state"] = observation
ret["pixels"] = images
return ret

View File

@@ -37,6 +37,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation
# And you keep sensor data rgb
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
@@ -60,6 +63,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
img /= 255
return_observations[imgkey] = img
# obs state agent qpos and qvel
# image
if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy(