Add diffusion policy (train and eval work, TODO: reproduce results)

Cadene
2024-02-28 15:21:30 +00:00
parent f1708c8a37
commit cf5063e50e
5 changed files with 125 additions and 31 deletions

File 1 (diffusion policy config):

@@ -13,7 +13,7 @@ shape_meta:
 shape: [2]
 horizon: 16
-n_obs_steps: 2
+n_obs_steps: 1 # TODO(rcadene): before 2
 n_action_steps: 8
 n_latency_steps: 0
 dataset_obs_steps: ${n_obs_steps}
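
For context on these three knobs: a diffusion policy denoises a whole action sequence of length horizon, conditioned on the last n_obs_steps observations, and executes only the first n_action_steps before replanning. A minimal, self-contained sketch of that receding-horizon loop (the predictor and observations are stand-ins, not this repo's API):

    import numpy as np

    horizon, n_obs_steps, n_action_steps, action_dim = 16, 1, 8, 4

    def predict_action_sequence(obs_stack):
        # stand-in for the denoising model: returns a full horizon of actions
        return np.zeros((horizon, action_dim))

    obs_history = [np.zeros(2)]                       # seeded with an initial observation
    for _ in range(3):                                # a few control cycles
        obs_stack = obs_history[-n_obs_steps:]        # condition on the last n_obs_steps obs
        actions = predict_action_sequence(obs_stack)  # shape (horizon, action_dim)
        for action in actions[:n_action_steps]:       # execute only the first n_action_steps
            obs_history.append(np.zeros(2))           # stand-in for env.step(action)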
@@ -51,6 +51,10 @@ policy:
 balanced_sampling: true
 utd: 1
 offline_steps: ${offline_steps}
+use_ema: true
+lr_scheduler: cosine
+lr_warmup_steps: 500
 noise_scheduler:
   # _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
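
These new policy-level options map onto standard diffusers utilities. A minimal sketch of how use_ema, lr_scheduler: cosine, lr_warmup_steps: 500, and the (commented-out) DDPMScheduler target are typically wired up, assuming the diffusers library; the model, timestep count, and total step budget below are placeholders:

    import torch
    from diffusers.optimization import get_scheduler
    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
    from diffusers.training_utils import EMAModel

    model = torch.nn.Linear(2, 2)                     # placeholder for the policy network
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    noise_scheduler = DDPMScheduler(num_train_timesteps=100)  # timestep count assumed

    lr_scheduler = get_scheduler(
        "cosine",                                     # lr_scheduler: cosine
        optimizer=optimizer,
        num_warmup_steps=500,                         # lr_warmup_steps: 500
        num_training_steps=50_000,                    # placeholder for the total step budget
    )

    # use_ema: true; call ema.step(model.parameters()) after each optimizer step
    ema = EMAModel(model.parameters())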
@@ -99,13 +103,13 @@ training:
 debug: False
 resume: True
 # optimization
-lr_scheduler: cosine
-lr_warmup_steps: 500
+# lr_scheduler: cosine
+# lr_warmup_steps: 500
 num_epochs: 8000
-gradient_accumulate_every: 1
+# gradient_accumulate_every: 1
 # EMA destroys performance when used with BatchNorm
 # replace BatchNorm with GroupNorm.
-use_ema: True
+# use_ema: True
 freeze_encoder: False
 # training loop control
 # in epochs
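
The comment about EMA and BatchNorm refers to a known pitfall: BatchNorm keeps running statistics that an exponential moving average of the weights does not track, so diffusion-policy implementations typically swap BatchNorm for GroupNorm in the vision encoder. A minimal sketch of that swap, assuming a torchvision ResNet (the repo's own helper may differ):

    import torch.nn as nn
    import torchvision

    def replace_bn_with_gn(module: nn.Module, groups: int = 16) -> nn.Module:
        # recursively swap every BatchNorm2d for a GroupNorm of the same width
        for name, child in module.named_children():
            if isinstance(child, nn.BatchNorm2d):
                # GroupNorm keeps no running statistics, so averaging its
                # weights with an EMA is safe, unlike BatchNorm
                setattr(module, name, nn.GroupNorm(groups, child.num_features))
            else:
                replace_bn_with_gn(child, groups)
        return module

    encoder = replace_bn_with_gn(torchvision.models.resnet18(weights=None))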

File 2 (fowm policy config):

@@ -62,7 +62,7 @@ policy:
 A_scaling: 3.0
 # offline->online
-offline_steps: 25000 # ${train_steps}/2
+offline_steps: ${offline_steps}
 pretrained_model_path: ""
 # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
 # pretrained_model_path: "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
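
This change replaces a hardcoded step count with a Hydra/OmegaConf interpolation, so the policy section inherits whatever the top-level offline_steps is set to. A minimal sketch of how that resolution behaves, assuming OmegaConf (which Hydra uses underneath); the value is illustrative:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "offline_steps": 25000,
        "policy": {"offline_steps": "${offline_steps}"},  # resolves against the root
    })
    assert cfg.policy.offline_steps == 25000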
@@ -73,4 +73,4 @@ policy:
 enc_dim: 256
 num_q: 5
 mlp_dim: 512
-latent_dim: 50
\ No newline at end of file
+latent_dim: 50