backup wip

This commit is contained in:
Alexander Soare
2024-03-20 15:01:27 +00:00
parent 32e3f71dd1
commit d323993569
7 changed files with 71 additions and 81 deletions

View File

@@ -1,9 +1,11 @@
import copy
import logging
import time
import hydra
import torch
from lerobot.common.ema import update_ema_parameters
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()
self.ema = None
if self.cfg.use_ema:
self.ema = hydra.utils.instantiate(
cfg_ema,
model=copy.deepcopy(self.diffusion),
)
self.ema_diffusion = None
if self.cfg.ema.enable:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
@@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy):
@torch.no_grad()
def select_actions(self, observation, step_count):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
"""
# TODO(rcadene): remove unused step_count
del step_count
@@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy):
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
action = out["action"]
return action
@@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy):
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
if self.cfg.ema.enable:
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
info = {
"loss": loss.item(),
@@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy):
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0