Refactor SAC policy with performance optimizations and multi-camera support

- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks

Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
AdilZouitine
2025-02-20 17:14:27 +00:00
parent ff47c0b0d3
commit ff82367c62
4 changed files with 153 additions and 93 deletions

@@ -10,7 +10,6 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
@@ -42,6 +41,7 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
+return_observations["observation.image.2"] = img
return_observations["observation.state"] = state
return return_observations
@@ -142,7 +142,7 @@ def make_maniskill(
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
-env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)
+env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1)
return env
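
The last hunk changes the action scaling factor from 10.0 to 1. The body of ManiSkillMultiplyActionWrapper is not part of this diff, so the following is only a minimal sketch of what such a scaling wrapper might do, assuming it multiplies each action component by multiply_factor before forwarding it to the wrapped environment. StubEnv and MultiplyActionSketch are hypothetical names used for illustration and do not come from the repository.

```python
from typing import Any, List, Optional


class StubEnv:
    """Hypothetical stand-in for the wrapped environment; it only records the last action."""

    def __init__(self) -> None:
        self.last_action: Optional[List[float]] = None

    def step(self, action: List[float]) -> Any:
        # Store a copy so later mutation of the caller's list has no effect.
        self.last_action = list(action)
        # Gymnasium-style return: observation, reward, terminated, truncated, info.
        return {}, 0.0, False, False, {}


class MultiplyActionSketch:
    """Illustrative wrapper (assumed behavior): scale each action component, then forward it."""

    def __init__(self, env: StubEnv, multiply_factor: float = 1.0) -> None:
        self.env = env
        self.multiply_factor = multiply_factor

    def step(self, action: List[float]) -> Any:
        scaled = [component * self.multiply_factor for component in action]
        return self.env.step(scaled)


# Old setting: actions are amplified tenfold before reaching the environment.
old_env = StubEnv()
MultiplyActionSketch(old_env, multiply_factor=10.0).step([0.5, -0.25])

# New setting: a factor of 1 passes the policy's actions through unchanged.
new_env = StubEnv()
MultiplyActionSketch(new_env, multiply_factor=1).step([0.5, -0.25])
```

Under this assumption, a factor of 1 makes the wrapper a pass-through, so the magnitude of the actions reaching the environment is determined entirely by the policy output.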