- Fixed big issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to the update_policy_parameters and remove strict=False

- Fixed big issue in the normalization of the actions in the `forward` function of the critic -- remove the `torch.no_grad` decorator in `normalize.py` in the normalization function
- Fixed performance issue to boost the optimization frequency by setting the storage device to be the same as the device of learning.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-19 16:22:51 +00:00
parent 62e237bdee
commit 0d88a5ee09
7 changed files with 68 additions and 57 deletions

View File

@@ -3,10 +3,14 @@ import numpy as np
import gymnasium as gym
import torch
from omegaconf import DictConfig
from typing import Any
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
@@ -43,32 +47,29 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
def observation(self, observation):
return preprocess_maniskill_observation(observation)
class ManiSkillToDeviceWrapper(gym.Wrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
self.device = device
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options=options)
obs = {k: v.to(self.device) for k, v in obs.items()}
return obs, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
obs = {k: v.to(self.device) for k, v in obs.items()}
return obs, reward, terminated, truncated, info
def observation(self, observation):
observation = preprocess_maniskill_observation(observation)
observation = {k: v.to(self.device) for k, v in observation.items()}
return observation
class ManiSkillCompat(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
options = {}
return super().reset(seed=seed, options=options)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
@@ -89,7 +90,7 @@ class ManiSkillActionWrapper(gym.ActionWrapper):
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
def __init__(self, env, multiply_factor: float = 10):
def __init__(self, env, multiply_factor: float = 1):
super().__init__(env)
self.multiply_factor = multiply_factor
action_space_agent: gym.spaces.Box = env.action_space[0]
@@ -108,13 +109,8 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
def make_maniskill(
task: str = "PushCube-v1",
obs_mode: str = "rgb",
control_mode: str = "pd_ee_delta_pose",
render_mode: str = "rgb_array",
sensor_configs: dict[str, int] | None = None,
n_envs: int = 1,
device: torch.device = "cuda",
cfg: DictConfig,
n_envs: int | None = None,
) -> gym.Env:
"""
Factory function to create a ManiSkill environment with standard wrappers.
@@ -130,22 +126,24 @@ def make_maniskill(
Returns:
A wrapped ManiSkill environment
"""
if sensor_configs is None:
sensor_configs = {"width": 64, "height": 64}
env = gym.make(
task,
obs_mode=obs_mode,
control_mode=control_mode,
render_mode=render_mode,
sensor_configs=sensor_configs,
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
num_envs=n_envs,
)
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillObservationWrapper(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env)
env = ManiSkillToDeviceWrapper(env, device=device)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)
return env