Add ManiSkill support.

Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
This commit is contained in:
AdilZouitine
2025-02-14 19:53:29 +00:00
parent 291358d6a2
commit b7a0ffc3b8
6 changed files with 222 additions and 27 deletions

View File

@@ -5,11 +5,16 @@ fps: 20
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 64
image_size: 128
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 64
render_size: 128
device: cuda
reward_classifier:
pretrained_path: null
config_path: null

View File

@@ -8,7 +8,7 @@
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: aractingi/hil-serl-maniskill-pushcube
dataset_repo_id: null
training:
# Offline training dataloader
@@ -20,7 +20,7 @@ training:
lr: 3e-4
eval_freq: 2500
log_freq: 500
log_freq: 10
save_freq: 2000000
online_steps: 1000000
@@ -52,14 +52,16 @@ policy:
n_action_steps: 1
shared_encoder: true
# vision_encoder_name: null
vision_encoder_name: null
# vision_encoder_name: "helper2424/resnet10"
# freeze_vision_encoder: true
freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 64, 64]
observation.image: [3, 128, 128]
output_shapes:
action: ["${env.action_dim}"]
action: [7]
# Normalization / Unnormalization
input_normalization_modes: null
@@ -67,8 +69,8 @@ policy:
action: min_max
output_normalization_params:
action:
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
# Architecture / modeling.
# Neural networks.
@@ -88,14 +90,3 @@ policy:
actor_learner_config:
actor_ip: "127.0.0.1"
port: 50051
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995