ready for review

This commit is contained in:
Alexander Soare
2024-03-21 10:18:50 +00:00
parent d323993569
commit acf1174447
12 changed files with 282 additions and 85 deletions

View File

@@ -5,7 +5,6 @@ import time
import hydra
import torch
from lerobot.common.ema import update_ema_parameters
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -21,6 +20,7 @@ class DiffusionPolicy(AbstractPolicy):
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
@@ -71,8 +71,13 @@ class DiffusionPolicy(AbstractPolicy):
self.diffusion.cuda()
self.ema_diffusion = None
if self.cfg.ema.enable:
self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = hydra.utils.instantiate(
cfg_ema,
model=self.ema_diffusion,
)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
@@ -175,8 +180,8 @@ class DiffusionPolicy(AbstractPolicy):
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.cfg.ema.enable:
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
if self.ema is not None:
self.ema.step(self.diffusion)
info = {
"loss": loss.item(),