Merge pull request #6 from Cadene/user/rcadene/2024_03_04_diffusion

Make diffusion work
This commit is contained in:
Remi
2024-03-04 18:30:40 +01:00
committed by GitHub
12 changed files with 276 additions and 142 deletions

View File

@@ -1,7 +1,7 @@
defaults:
- _self_
- env: simxarm
- policy: tdmpc
- env: pusht
- policy: diffusion
hydra:
run:
@@ -21,6 +21,7 @@ save_buffer: false
train_steps: ???
fps: ???
n_action_steps: ???
env: ???
policy: ???

View File

@@ -13,7 +13,7 @@ shape_meta:
shape: [2]
horizon: 16
n_obs_steps: 1 # TODO(rcadene): before 2
n_obs_steps: 2
n_action_steps: 8
n_latency_steps: 0
dataset_obs_steps: ${n_obs_steps}
@@ -21,7 +21,7 @@ past_action_visible: False
keypoint_visible_rate: 1.0
obs_as_global_cond: True
eval_episodes: 50
eval_episodes: 1
eval_freq: 10000
save_freq: 100000
log_freq: 250
@@ -40,8 +40,8 @@ policy:
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null
diffusion_step_embed_dim: 128
down_dims: [512, 1024, 2048]
diffusion_step_embed_dim: 256 # before 128
down_dims: [256, 512, 1024] # before [512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
@@ -59,10 +59,10 @@ policy:
use_ema: true
lr_scheduler: cosine
lr_warmup_steps: 500
grad_clip_norm: 0
grad_clip_norm: 10
noise_scheduler:
# _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
num_train_timesteps: 100
beta_start: 0.0001
beta_end: 0.02
@@ -74,16 +74,16 @@ noise_scheduler:
obs_encoder:
# _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
shape_meta: ${shape_meta}
resize_shape: null
crop_shape: [76, 76]
# resize_shape: null
# crop_shape: [76, 76]
# constant center crop
random_crop: True
# random_crop: True
use_group_norm: True
share_rgb_model: False
imagenet_norm: False # TODO(rcadene): was set to True
imagenet_norm: True
rgb_model:
#_target_: diffusion_policy.model.vision.model_getter.get_resnet
_target_: diffusion_policy.model.vision.model_getter.get_resnet
name: resnet18
weights: null

View File

@@ -1,5 +1,7 @@
# @package _global_
n_action_steps: 1
policy:
name: tdmpc