forked from tangger/lerobot
Refactor SAC policy with performance optimizations and multi-camera support
- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks

Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
@@ -10,7 +10,6 @@ from typing import Any
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

-

 def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
     """Convert environment observation to LeRobot format observation.
     Args:
@@ -42,6 +41,7 @@ def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dic
     state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)

     return_observations["observation.image"] = img
+    return_observations["observation.image.2"] = img
     return_observations["observation.state"] = state
     return return_observations

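The hunk above stores images under two keys, `observation.image` and `observation.image.2`, and assigns the same `img` to both, so the second camera duplicates the first here. As a rough illustration of how downstream code might discover every camera entry in such an observation, the sketch below filters keys by a shared prefix. The helper `collect_image_keys` and the constant `IMAGE_KEY_PREFIX` are hypothetical names introduced for this example; they are not taken from the repository, and the string values stand in for the tensors the real function stores.

```python
from typing import Any, Mapping

# Assumption: every camera key shares this prefix, as the two keys in the hunk do.
IMAGE_KEY_PREFIX = "observation.image"


def collect_image_keys(observation: Mapping[str, Any]) -> list[str]:
    """Return the camera keys in sorted order (hypothetical helper)."""
    return sorted(key for key in observation if key.startswith(IMAGE_KEY_PREFIX))


# Stand-in values; the function in the diff stores tensors under these keys.
observation = {
    "observation.image": "img",
    "observation.image.2": "img",
    "observation.state": "state",
}

image_keys = collect_image_keys(observation)
print(image_keys)
```

Sorting keeps the camera order stable across calls, which matters if an encoder stacks the images in key order.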
@@ -142,7 +142,7 @@ def make_maniskill(
     env.unwrapped.metadata["render_fps"] = 20
     env = ManiSkillCompat(env)
     env = ManiSkillActionWrapper(env)
-    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)
+    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1)

     return env
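The two `ManiSkillMultiplyActionWrapper` lines above differ only in `multiply_factor` (10.0 versus 1). To make the effect of that argument concrete, here is a minimal, dependency-free sketch of an action-scaling wrapper. The class `MultiplyAction` and the function `echo_step` are hypothetical stand-ins written for this example; the wrapper's actual implementation is not shown in this diff, and this sketch only assumes that it multiplies each action component by the factor before stepping the environment.

```python
from typing import Callable, Sequence


class MultiplyAction:
    """Hypothetical stand-in for an action-scaling wrapper; not the repository's code."""

    def __init__(self, step_fn: Callable[[list[float]], list[float]], multiply_factor: float = 1.0):
        self.step_fn = step_fn
        self.multiply_factor = multiply_factor

    def step(self, action: Sequence[float]) -> list[float]:
        # Scale every action component, then forward it to the wrapped step function.
        scaled = [component * self.multiply_factor for component in action]
        return self.step_fn(scaled)


def echo_step(action: list[float]) -> list[float]:
    """Stand-in environment step that returns the action it received."""
    return action


scaled_by_ten = MultiplyAction(echo_step, multiply_factor=10.0).step([0.5, -0.25])
unscaled = MultiplyAction(echo_step, multiply_factor=1).step([0.5, -0.25])
print(scaled_by_ten, unscaled)
```

Under this assumption, a factor of 1 passes actions through unchanged, so the policy's raw output magnitude is what reaches the environment.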