Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)
This commit is contained in:
@@ -4,9 +4,29 @@ def make_policy(cfg):
|
||||
|
||||
policy = TDMPC(cfg.policy)
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import (
|
||||
MultiImageObsEncoder,
|
||||
)
|
||||
|
||||
from lerobot.common.policies.diffusion import DiffusionPolicy
|
||||
|
||||
policy = DiffusionPolicy(cfg.policy)
|
||||
noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
|
||||
|
||||
rgb_model = get_resnet(**cfg.rgb_model)
|
||||
|
||||
obs_encoder = MultiImageObsEncoder(
|
||||
rgb_model=rgb_model,
|
||||
**cfg.obs_encoder,
|
||||
)
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
noise_scheduler=noise_scheduler,
|
||||
obs_encoder=obs_encoder,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user