Add diffusion policy (train and eval work, TODO: reproduce results)

This commit is contained in:
Cadene
2024-02-28 15:21:30 +00:00
parent f1708c8a37
commit cf5063e50e
5 changed files with 125 additions and 31 deletions

View File

@@ -1,7 +1,12 @@
import copy
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
@@ -10,9 +15,13 @@ class DiffusionPolicy(nn.Module):
def __init__(
self,
cfg,
cfg_noise_scheduler,
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
obs_encoder: MultiImageObsEncoder,
horizon,
n_action_steps,
n_obs_steps,
@@ -27,6 +36,15 @@ class DiffusionPolicy(nn.Module):
**kwargs,
):
super().__init__()
self.cfg = cfg
noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
rgb_model = get_resnet(**cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
)
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@@ -44,3 +62,91 @@ class DiffusionPolicy(nn.Module):
# parameters passed to step
**kwargs,
)
self.device = torch.device("cuda")
self.diffusion.cuda()
self.ema = None
if self.cfg.use_ema:
self.ema = hydra.utils.instantiate(
cfg_ema,
model=copy.deepcopy(self.diffusion),
)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
params=self.diffusion.parameters(),
)
# TODO(rcadene): modify lr scheduler so that it depends on steps rather than epochs
self.global_step = 0
# configure lr scheduler
self.lr_scheduler = get_scheduler(
cfg.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=cfg.offline_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=self.global_step - 1,
)
@torch.no_grad()
def forward(self, observation, step_count):
    """Predict the next action for a single observation.

    Args:
        observation: Mapping with an ``"image"`` entry and a ``"state"``
            entry for one environment step (no batch or time axis).
        step_count: Ignored; see the TODO below.

    Returns:
        The first action of the sequence returned by
        ``self.diffusion.predict_action``, with the batch axis removed.
    """
    # TODO(rcadene): remove unused step_count
    del step_count

    # Prepend singleton batch and time axes:
    # c h w -> b t c h w (b=1, t=1)
    image = observation["image"][None, None, ...]
    agent_pos = observation["state"][None, None, ...]
    prediction = self.diffusion.predict_action(
        {"image": image, "agent_pos": agent_pos}
    )

    # TODO(rcadene): add possibility to return >1 timesteps
    first_action_index = 0
    action_sequence = prediction["action"].squeeze(0)
    return action_sequence[first_action_index]
def update(self, replay_buffer, step):
    """Run one gradient step on the diffusion model and report metrics.

    ``cfg.batch_size`` is interpreted as the number of trajectory slices and
    every slice holds ``cfg.horizon`` steps, so ``batch_size * horizon``
    transitions are requested from the buffer in a single call.

    Args:
        replay_buffer: Object whose ``sample(n)`` returns a flat batch of
            ``n`` transitions that supports ``reshape`` and lookups by
            ``("observation", "image")``, ``("observation", "state")`` and
            ``"action"``.
        step: Current training step. Not used yet; kept so the call
            signature stays stable.

    Returns:
        dict: ``"total_loss"`` (Python float of this batch's loss) and
        ``"lr"`` (first learning rate reported by the scheduler after it
        has been stepped).
    """
    self.diffusion.train()

    horizon = self.cfg.horizon
    num_slices = self.cfg.batch_size
    # One sampled slice contributes `horizon` transitions.
    num_transitions = horizon * num_slices

    def process_batch(batch, horizon, num_slices):
        # (b t) ... -> b t ...   (b = num_slices, t = horizon)
        # Keep the batch axis first so the layout agrees with the
        # "b t c h w" observations that forward() hands to the same model.
        # The previous transpose produced a time-first (t b ...) batch.
        # Assumes the sampler lays the slices out back-to-back in the flat
        # batch, which is what this reshape relies on.
        batch = batch.reshape(num_slices, horizon)
        return {
            "obs": {
                "image": batch["observation", "image"].to(self.device),
                "agent_pos": batch["observation", "state"].to(self.device),
            },
            "action": batch["action"].to(self.device),
        }

    sampled = replay_buffer.sample(num_transitions)
    model_batch = process_batch(sampled, horizon, num_slices)

    loss = self.diffusion.compute_loss(model_batch)
    loss.backward()

    self.optimizer.step()
    # Clear gradients right after the step so the next call starts clean.
    self.optimizer.zero_grad()
    # The scheduler is stepped once per batch (not per epoch).
    self.lr_scheduler.step()

    if self.ema is not None:
        # Let the EMA wrapper track the freshly updated weights.
        self.ema.step(self.diffusion)

    return {
        "total_loss": loss.item(),
        "lr": self.lr_scheduler.get_last_lr()[0],
    }

View File

@@ -4,26 +4,15 @@ def make_policy(cfg):
policy = TDMPC(cfg.policy)
elif cfg.policy.name == "diffusion":
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import (
MultiImageObsEncoder,
)
from lerobot.common.policies.diffusion import DiffusionPolicy
noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
rgb_model = get_resnet(**cfg.rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg.obs_encoder,
)
policy = DiffusionPolicy(
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
cfg=cfg.policy,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)