Merge pull request #6 from Cadene/user/rcadene/2024_03_04_diffusion
Make diffusion work
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
defaults:
|
||||
- _self_
|
||||
- env: simxarm
|
||||
- policy: tdmpc
|
||||
- env: pusht
|
||||
- policy: diffusion
|
||||
|
||||
hydra:
|
||||
run:
|
||||
@@ -21,6 +21,7 @@ save_buffer: false
|
||||
train_steps: ???
|
||||
fps: ???
|
||||
|
||||
n_action_steps: ???
|
||||
env: ???
|
||||
|
||||
policy: ???
|
||||
|
||||
@@ -13,7 +13,7 @@ shape_meta:
|
||||
shape: [2]
|
||||
|
||||
horizon: 16
|
||||
n_obs_steps: 1 # TODO(rcadene): before 2
|
||||
n_obs_steps: 2
|
||||
n_action_steps: 8
|
||||
n_latency_steps: 0
|
||||
dataset_obs_steps: ${n_obs_steps}
|
||||
@@ -21,7 +21,7 @@ past_action_visible: False
|
||||
keypoint_visible_rate: 1.0
|
||||
obs_as_global_cond: True
|
||||
|
||||
eval_episodes: 50
|
||||
eval_episodes: 1
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
@@ -40,8 +40,8 @@ policy:
|
||||
num_inference_steps: 100
|
||||
obs_as_global_cond: ${obs_as_global_cond}
|
||||
# crop_shape: null
|
||||
diffusion_step_embed_dim: 128
|
||||
down_dims: [512, 1024, 2048]
|
||||
diffusion_step_embed_dim: 256 # before 128
|
||||
down_dims: [256, 512, 1024] # before [512, 1024, 2048]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
cond_predict_scale: True
|
||||
@@ -59,10 +59,10 @@ policy:
|
||||
use_ema: true
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
grad_clip_norm: 0
|
||||
grad_clip_norm: 10
|
||||
|
||||
noise_scheduler:
|
||||
# _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||
num_train_timesteps: 100
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
@@ -74,16 +74,16 @@ noise_scheduler:
|
||||
obs_encoder:
|
||||
# _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
shape_meta: ${shape_meta}
|
||||
resize_shape: null
|
||||
crop_shape: [76, 76]
|
||||
# resize_shape: null
|
||||
# crop_shape: [76, 76]
|
||||
# constant center crop
|
||||
random_crop: True
|
||||
# random_crop: True
|
||||
use_group_norm: True
|
||||
share_rgb_model: False
|
||||
imagenet_norm: False # TODO(rcadene): was set to True
|
||||
imagenet_norm: True
|
||||
|
||||
rgb_model:
|
||||
#_target_: diffusion_policy.model.vision.model_getter.get_resnet
|
||||
_target_: diffusion_policy.model.vision.model_getter.get_resnet
|
||||
name: resnet18
|
||||
weights: null
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# @package _global_
|
||||
|
||||
n_action_steps: 1
|
||||
|
||||
policy:
|
||||
name: tdmpc
|
||||
|
||||
|
||||
Reference in New Issue
Block a user