Add diffusion policy (train and eval works, TODO: reproduce results)

This commit is contained in:
Cadene
2024-02-28 15:21:30 +00:00
parent f1708c8a37
commit cf5063e50e
5 changed files with 125 additions and 31 deletions

View File

@@ -4,26 +4,15 @@ def make_policy(cfg):
policy = TDMPC(cfg.policy)
elif cfg.policy.name == "diffusion":
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import (
MultiImageObsEncoder,
)
from lerobot.common.policies.diffusion import DiffusionPolicy
noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
rgb_model = get_resnet(**cfg.rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg.obs_encoder,
)
policy = DiffusionPolicy(
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
cfg=cfg.policy,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)