Refactor SAC policy with performance optimizations and multi-camera support
- Introduced Ensemble and CriticHead classes for more efficient critic network handling - Added support for multiple camera inputs in observation encoder - Optimized image encoding by batching image processing - Updated configuration for ManiSkill environment with reduced image size and action scaling - Compiled critic networks for improved performance - Simplified normalization and ensemble handling in critic networks Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
This commit is contained in:
4
lerobot/configs/env/maniskill_example.yaml
vendored
4
lerobot/configs/env/maniskill_example.yaml
vendored
@@ -5,14 +5,14 @@ fps: 20
|
||||
env:
|
||||
name: maniskill/pushcube
|
||||
task: PushCube-v1
|
||||
image_size: 128
|
||||
image_size: 64
|
||||
control_mode: pd_ee_delta_pose
|
||||
state_dim: 25
|
||||
action_dim: 7
|
||||
fps: ${fps}
|
||||
obs: rgb
|
||||
render_mode: rgb_array
|
||||
render_size: 128
|
||||
render_size: 64
|
||||
device: cuda
|
||||
|
||||
reward_classifier:
|
||||
|
||||
@@ -59,32 +59,36 @@ policy:
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.image: [3, 128, 128]
|
||||
observation.image: [3, 64, 64]
|
||||
observation.image.2: [3, 64, 64]
|
||||
output_shapes:
|
||||
action: [7]
|
||||
|
||||
camera_number: 2
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.state: min_max
|
||||
input_normalization_params:
|
||||
observation.state:
|
||||
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
|
||||
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
|
||||
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
|
||||
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
|
||||
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
|
||||
input_normalization_modes: null
|
||||
# input_normalization_modes:
|
||||
# observation.state: min_max
|
||||
input_normalization_params: null
|
||||
# observation.state:
|
||||
# min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
|
||||
# 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
|
||||
# -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
|
||||
# -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
|
||||
# 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
|
||||
|
||||
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
|
||||
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
|
||||
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
|
||||
0.4001]
|
||||
# max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
|
||||
# 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
|
||||
# 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
|
||||
# 0.4001]
|
||||
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
action:
|
||||
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
|
||||
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
|
||||
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
output_normalization_shapes:
|
||||
action: [7]
|
||||
|
||||
@@ -94,8 +98,8 @@ policy:
|
||||
# discount: 0.99
|
||||
discount: 0.80
|
||||
temperature_init: 1.0
|
||||
num_critics: 2 #10
|
||||
num_subsample_critics: null
|
||||
num_critics: 10 #10
|
||||
num_subsample_critics: 2
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
|
||||
Reference in New Issue
Block a user