import copy
import logging
import time
from collections import deque

import hydra
import torch
from torch import nn

from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device

class DiffusionPolicy(nn.Module):
    """Diffusion policy (UNet image variant) wrapped for training and rollout.

    Bundles a `DiffusionUnetImagePolicy` together with its optimizer, LR scheduler,
    an optional EMA copy of the weights (used at evaluation time), and the
    observation/action queues that drive action chunking during rollout.
    """

    name = "diffusion"

    def __init__(
        self,
        cfg,
        cfg_device,
        cfg_noise_scheduler,
        cfg_rgb_model,
        cfg_obs_encoder,
        cfg_optimizer,
        cfg_ema,
        shape_meta: dict,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        obs_as_global_cond=True,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        cond_predict_scale=True,
        # parameters passed to step
        **kwargs,
    ):
        """Build the diffusion model, optimizer, LR scheduler and optional EMA.

        Args:
            cfg: top-level config; fields read here: `use_ema`, `lr_scheduler`,
                `lr_warmup_steps`, `offline_steps`, `grad_clip_norm`.
            cfg_device: device spec passed to `get_safe_torch_device`.
            cfg_noise_scheduler / cfg_rgb_model / cfg_obs_encoder / cfg_optimizer /
                cfg_ema: hydra-instantiable sub-configs for the respective components.
            shape_meta: observation/action shape metadata; `shape_meta.obs.image.shape`
                is read here (CHW, presumably — confirm against the dataset config).
            horizon / n_action_steps / n_obs_steps: action-chunking geometry forwarded
                to `DiffusionUnetImagePolicy`.
            Remaining keyword arguments are forwarded to the diffusion model.
        """
        super().__init__()
        self.cfg = cfg
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        # Queues are populated during rollout of the policy: they contain the n
        # latest observations and the predicted-but-not-yet-executed actions.
        self._queues = None

        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)

        # The RGB encoder sees the (optionally cropped) image, so its declared
        # input shape must match the crop, not the raw frame.
        rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
        if cfg_obs_encoder.crop_shape is not None:
            rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
        rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
        obs_encoder = MultiImageObsEncoder(
            rgb_model=rgb_model,
            **cfg_obs_encoder,
        )

        self.diffusion = DiffusionUnetImagePolicy(
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            obs_encoder=obs_encoder,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            obs_as_global_cond=obs_as_global_cond,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
            # parameters passed to step
            **kwargs,
        )

        self.device = get_safe_torch_device(cfg_device)
        self.diffusion.to(self.device)

        # EMA copy of the diffusion weights; only built when cfg.use_ema is set.
        # Evaluation prefers these weights when they exist (see select_action).
        self.ema_diffusion = None
        self.ema = None
        if self.cfg.use_ema:
            self.ema_diffusion = copy.deepcopy(self.diffusion)
            self.ema = hydra.utils.instantiate(
                cfg_ema,
                model=self.ema_diffusion,
            )

        self.optimizer = hydra.utils.instantiate(
            cfg_optimizer,
            params=self.diffusion.parameters(),
        )

        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
        self.global_step = 0

        # configure lr scheduler
        self.lr_scheduler = get_scheduler(
            cfg.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.lr_warmup_steps,
            num_training_steps=cfg.offline_steps,
            # pytorch assumes stepping LRScheduler every epoch
            # however huggingface diffusers steps it every batch
            last_epoch=self.global_step - 1,
        )

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`."""
        self._queues = {
            "observation.image": deque(maxlen=self.n_obs_steps),
            "observation.state": deque(maxlen=self.n_obs_steps),
            "action": deque(maxlen=self.n_action_steps),
        }

    @torch.no_grad()
    def select_action(self, batch, step):
        """Return the next action for the environment, refilling the action queue
        with a new predicted chunk whenever it runs empty.

        Note: this uses the ema model weights if self.training == False, otherwise
        the non-ema model weights.
        """
        # TODO(rcadene): remove unused step
        del step
        assert "observation.image" in batch
        assert "observation.state" in batch
        assert len(batch) == 2

        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # stack n latest observations from the queue
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}

            obs_dict = {
                "image": batch["observation.image"],
                "agent_pos": batch["observation.state"],
            }
            # FIX: when cfg.use_ema is False, self.ema_diffusion is None; the
            # previous code crashed with AttributeError in eval mode. Fall back
            # to the non-ema weights in that case.
            if self.training or self.ema_diffusion is None:
                out = self.diffusion.predict_action(obs_dict)
            else:
                out = self.ema_diffusion.predict_action(obs_dict)
            # Predicted chunk is (batch, n_action_steps, ...); transpose so the
            # queue holds one per-step entry at a time.
            self._queues["action"].extend(out["action"].transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch, step):
        """Run one optimization step on `batch` and return a dict of logging info
        (loss, grad_norm, lr, data_s, update_s)."""
        start_time = time.time()

        self.diffusion.train()

        # NOTE(review): measured immediately after start_time, so this is ~0 and
        # does not reflect dataloading time — kept for log-format compatibility.
        data_s = time.time() - start_time

        obs_dict = {
            "image": batch["observation.image"],
            "agent_pos": batch["observation.state"],
        }
        # NOTE(review): batch["action"] is not passed explicitly — presumably
        # compute_loss reads the action target some other way; confirm upstream.
        loss = self.diffusion.compute_loss(obs_dict)
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.diffusion.parameters(),
            self.cfg.grad_clip_norm,
            error_if_nonfinite=False,
        )

        self.optimizer.step()
        self.optimizer.zero_grad()
        self.lr_scheduler.step()

        if self.ema is not None:
            self.ema.step(self.diffusion)

        info = {
            "loss": loss.item(),
            "grad_norm": float(grad_norm),
            "lr": self.lr_scheduler.get_last_lr()[0],
            "data_s": data_s,
            "update_s": time.time() - start_time,
        }

        # TODO(rcadene): remove hardcoding
        # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
        if step % 168 == 0:
            self.global_step += 1

        return info

    def save(self, fp):
        """Serialize the full state dict (diffusion + EMA weights) to `fp`."""
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        """Load a checkpoint saved with `save`, tolerating checkpoints that were
        written without EMA weights (warns instead of failing)."""
        # FIX: map_location lets a GPU-saved checkpoint load on a CPU-only host
        # (torch.load otherwise tries to restore tensors to their saved device).
        d = torch.load(fp, map_location=self.device)
        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
        if len(missing_keys) > 0:
            # Only EMA weights may legitimately be absent from the checkpoint.
            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
            logging.warning(
                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
            )
        assert len(unexpected_keys) == 0
|