# Files
# lerobot_piper/lerobot/common/policies/diffusion/policy.py
# Alexander Soare 976a197f98 backup wip
# 2024-04-11 17:51:35 +01:00
#
# 184 lines
# 5.9 KiB
# Python

import copy
import logging
import time
from collections import deque
import hydra
import torch
from diffusers.optimization import get_scheduler
from torch import Tensor, nn
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
class DiffusionPolicy(nn.Module):
    """Wrapper around `DiffusionUnetImagePolicy` bundling rollout and training state.

    Besides the diffusion model itself, an instance owns:
      * observation / action queues used by `select_action` during rollouts,
      * an optional EMA copy of the model (enabled by `cfg.use_ema`),
      * the optimizer and the learning-rate scheduler stepped by `update`.
    """

    name = "diffusion"

    def __init__(
        self,
        cfg,
        cfg_device,
        cfg_noise_scheduler,
        cfg_optimizer,
        cfg_ema,
        shape_meta: dict,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        film_scale_modulation=True,
        **_,
    ):
        """Build the model, its optional EMA copy, the optimizer and the LR scheduler.

        Args:
            cfg: Policy config. This class reads `use_ema`, `lr_scheduler`,
                `lr_warmup_steps`, `offline_steps` and (in `update`) `grad_clip_norm`;
                the object is also forwarded to `DiffusionUnetImagePolicy`.
            cfg_device: Device spec handed to `get_safe_torch_device`.
            cfg_noise_scheduler: Hydra config instantiated into the noise scheduler.
            cfg_optimizer: Hydra config instantiated with the model parameters.
            cfg_ema: Hydra config instantiated with the EMA model copy; only used when
                `cfg.use_ema` is truthy.
            shape_meta, horizon, num_inference_steps, diffusion_step_embed_dim,
            down_dims, kernel_size, n_groups, film_scale_modulation: Forwarded
                unchanged to `DiffusionUnetImagePolicy`.
            n_action_steps: Forwarded to the model; also the max length of the action queue.
            n_obs_steps: Forwarded to the model; also the max length of each observation queue.
            **_: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.cfg = cfg
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

        # TODO(now): In-house this.
        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)

        self.diffusion = DiffusionUnetImagePolicy(
            cfg,
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            film_scale_modulation=film_scale_modulation,
        )

        self.device = get_safe_torch_device(cfg_device)
        self.diffusion.to(self.device)

        # TODO(alexander-soare): This should probably be managed outside of the policy class.
        self.ema_diffusion = None
        self.ema = None
        if self.cfg.use_ema:
            # The copy is taken after the move to `self.device`, so it lives on the same device.
            self.ema_diffusion = copy.deepcopy(self.diffusion)
            self.ema = hydra.utils.instantiate(
                cfg_ema,
                model=self.ema_diffusion,
            )

        # Only the online model is optimized; the EMA copy is updated in `update`.
        self.optimizer = hydra.utils.instantiate(
            cfg_optimizer,
            params=self.diffusion.parameters(),
        )

        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
        self.global_step = 0

        # configure lr scheduler
        self.lr_scheduler = get_scheduler(
            cfg.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.lr_warmup_steps,
            num_training_steps=cfg.offline_steps,
            # pytorch assumes stepping LRScheduler every epoch
            # however huggingface diffusers steps it every batch
            last_epoch=self.global_step - 1,
        )

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`
        """
        self._queues = {
            "observation.image": deque(maxlen=self.n_obs_steps),
            "observation.state": deque(maxlen=self.n_obs_steps),
            "action": deque(maxlen=self.n_action_steps),
        }

    def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
        """Alias for `select_action`: return the next action for `batch`.

        No loss is computed here; use `compute_loss` or `update` for training.
        """
        return self.select_action(batch)

    # The call form `torch.no_grad()` is accepted by every PyTorch release, whereas the
    # bare `@torch.no_grad` form is only supported by newer ones.
    @torch.no_grad()
    def select_action(self, batch, **_):
        """Return one action, generating a fresh chunk of actions when the queue is empty.

        The incoming observations are pushed into the observation queues. When the action
        queue is empty, the queued observations are stacked along dim 1, a new sequence of
        actions is generated and queued, and the oldest queued action is returned.

        Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
        # TODO(now): Handle a batch

        Args:
            batch: Dict with exactly the keys "observation.image" and "observation.state".
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The next action popped from the action queue.
        """
        assert "observation.image" in batch
        assert "observation.state" in batch
        assert len(batch) == 2  # TODO(now): Does this not have a batch dim?

        # `reset()` may not have been called yet (the queues start out as None); create
        # them lazily instead of handing None to `populate_queues`.
        if self._queues is None:
            self.reset()

        self._queues = populate_queues(self._queues, batch)

        if not self._queues["action"]:
            # stack n latest observations from the queue
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
            actions = self._generate_actions(batch)
            # Queue one entry per index of dim 1 of `actions`.
            # NOTE(review): presumably `actions` is (batch, n_action_steps, action_dim) - confirm.
            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def _generate_actions(self, batch):
        """Predict actions with the EMA model in eval mode (if there is one), else the online model."""
        if not self.training and self.ema_diffusion is not None:
            return self.ema_diffusion.predict_action(batch)
        return self.diffusion.predict_action(batch)

    def update(self, batch, **_):
        """Run the model in train mode, compute the loss, and do an optimization step.

        Args:
            batch: Training batch forwarded to `compute_loss`.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            Dict with "loss", "grad_norm", "lr" (first param group) and "update_s"
            (wall-clock seconds spent in this call).
        """
        # perf_counter is monotonic, so the measured duration cannot go negative or jump
        # when the system clock is adjusted.
        start_time = time.perf_counter()

        self.diffusion.train()

        loss = self.compute_loss(batch)
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.diffusion.parameters(),
            self.cfg.grad_clip_norm,
            error_if_nonfinite=False,
        )

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.lr_scheduler.step()

        # Let the EMA object track the freshly updated online weights.
        if self.ema is not None:
            self.ema.step(self.diffusion)

        info = {
            "loss": loss.item(),
            "grad_norm": float(grad_norm),
            "lr": self.lr_scheduler.get_last_lr()[0],
            "update_s": time.perf_counter() - start_time,
        }
        return info

    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
        """Return the training loss of the online (non-EMA) model on `batch`."""
        return self.diffusion.compute_loss(batch)

    def save(self, fp):
        """Write this module's `state_dict` to `fp` with `torch.save`."""
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        """Load a state dict written by `save` into this policy.

        Checkpoints lacking the EMA weights are tolerated (a warning is logged); any
        other missing key, or any unexpected key, raises `AssertionError`.

        Args:
            fp: Path or file-like object accepted by `torch.load`.
        """
        # SECURITY: `torch.load` unpickles its input; only load checkpoints from trusted sources.
        # Deserialize onto the CPU so checkpoints saved on a GPU can be loaded on machines
        # without that device; `load_state_dict` then copies the values into the existing
        # parameters on whatever device they live on.
        state = torch.load(fp, map_location="cpu")
        missing_keys, unexpected_keys = self.load_state_dict(state, strict=False)
        if len(missing_keys) > 0:
            assert all(
                k.startswith("ema_diffusion.") for k in missing_keys
            ), f"Missing non-EMA keys in loaded state dict: {missing_keys}"
            logging.warning(
                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
            )
        assert len(unexpected_keys) == 0, f"Unexpected keys in loaded state dict: {unexpected_keys}"