Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)

This commit is contained in:
Cadene
2024-02-26 01:10:09 +00:00
parent 5a219fed6e
commit 21670dce90
12 changed files with 306 additions and 443 deletions

View File

@@ -4,9 +4,29 @@ def make_policy(cfg):
policy = TDMPC(cfg.policy)
elif cfg.policy.name == "diffusion":
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import (
MultiImageObsEncoder,
)
from lerobot.common.policies.diffusion import DiffusionPolicy
policy = DiffusionPolicy(cfg.policy)
noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
rgb_model = get_resnet(**cfg.rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg.obs_encoder,
)
policy = DiffusionPolicy(
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
else:
raise ValueError(cfg.policy.name)