Refactor SAC policy with performance optimizations and multi-camera support

- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks
Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
This commit is contained in:
AdilZouitine
2025-02-20 17:14:27 +00:00
parent 795063aa1b
commit 150def839c
4 changed files with 153 additions and 93 deletions

View File

@@ -5,14 +5,14 @@ fps: 20
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 128
image_size: 64
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 128
render_size: 64
device: cuda
reward_classifier:

View File

@@ -59,32 +59,36 @@ policy:
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 128, 128]
observation.image: [3, 64, 64]
observation.image.2: [3, 64, 64]
output_shapes:
action: [7]
camera_number: 2
# Normalization / Unnormalization
input_normalization_modes:
observation.state: min_max
input_normalization_params:
observation.state:
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
input_normalization_modes: null
# input_normalization_modes:
# observation.state: min_max
input_normalization_params: null
# observation.state:
# min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
0.4001]
# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
# 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
# 0.4001]
output_normalization_modes:
action: min_max
output_normalization_params:
action:
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
output_normalization_shapes:
action: [7]
@@ -94,8 +98,8 @@ policy:
# discount: 0.99
discount: 0.80
temperature_init: 1.0
num_critics: 2 #10
num_subsample_critics: null
num_critics: 10 #10
num_subsample_critics: 2
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4