backup wip
This commit is contained in:
@@ -5,11 +5,10 @@ from collections import deque
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
from torch import nn
|
||||
from diffusers.optimization import get_scheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
|
||||
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.utils import populate_queues
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
@@ -22,8 +21,6 @@ class DiffusionPolicy(nn.Module):
|
||||
cfg,
|
||||
cfg_device,
|
||||
cfg_noise_scheduler,
|
||||
cfg_rgb_model,
|
||||
cfg_obs_encoder,
|
||||
cfg_optimizer,
|
||||
cfg_ema,
|
||||
shape_meta: dict,
|
||||
@@ -31,53 +28,43 @@ class DiffusionPolicy(nn.Module):
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
obs_as_global_cond=True,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=(256, 512, 1024),
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
cond_predict_scale=True,
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
film_scale_modulation=True,
|
||||
**_,
|
||||
):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.n_action_steps = n_action_steps
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
# TODO(now): In-house this.
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
|
||||
if cfg_obs_encoder.crop_shape is not None:
|
||||
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
|
||||
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
|
||||
obs_encoder = MultiImageObsEncoder(
|
||||
rgb_model=rgb_model,
|
||||
**cfg_obs_encoder,
|
||||
)
|
||||
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
cfg,
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
obs_encoder=obs_encoder,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
n_obs_steps=n_obs_steps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
obs_as_global_cond=obs_as_global_cond,
|
||||
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||
down_dims=down_dims,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
film_scale_modulation=film_scale_modulation,
|
||||
)
|
||||
|
||||
self.device = get_safe_torch_device(cfg_device)
|
||||
self.diffusion.to(self.device)
|
||||
|
||||
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
||||
self.ema_diffusion = None
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
@@ -116,42 +103,45 @@ class DiffusionPolicy(nn.Module):
|
||||
"action": deque(maxlen=self.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch, step):
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
    """Run the policy on `batch` by delegating to `select_action`.

    No loss is computed here; use `compute_loss` for that.

    Args:
        batch: Mapping from string keys to tensors, handed to
            `select_action` unchanged.
        **_: Extra keyword arguments; accepted for call-site compatibility
            and ignored.

    Returns:
        Whatever `self.select_action(batch)` returns.
    """
    return self.select_action(batch)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch, **_):
|
||||
"""
|
||||
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
||||
# TODO(now): Handle a batch
|
||||
"""
|
||||
# TODO(rcadene): remove unused step
|
||||
del step
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
assert len(batch) == 2 # TODO(now): Does this not have a batch dim?
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
obs_dict = {
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
if self.training:
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
else:
|
||||
out = self.ema_diffusion.predict_action(obs_dict)
|
||||
self._queues["action"].extend(out["action"].transpose(0, 1))
|
||||
actions = self._generate_actions(batch)
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch, step):
|
||||
def _generate_actions(self, batch):
|
||||
if not self.training and self.ema_diffusion is not None:
|
||||
return self.ema_diffusion.predict_action(batch)
|
||||
else:
|
||||
return self.diffusion.predict_action(batch)
|
||||
|
||||
def update(self, batch, **_):
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
loss = self.compute_loss(batch)
|
||||
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
@@ -174,13 +164,11 @@ class DiffusionPolicy(nn.Module):
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
# TODO(rcadene): remove hardcoding
|
||||
# in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
|
||||
if step % 168 == 0:
|
||||
self.global_step += 1
|
||||
|
||||
return info
|
||||
|
||||
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
    """Return the loss that the wrapped `self.diffusion` model computes for `batch`.

    Args:
        batch: Mapping from string keys to tensors, forwarded unchanged to
            `self.diffusion.compute_loss`.

    Returns:
        The value returned by `self.diffusion.compute_loss(batch)`.
    """
    return self.diffusion.compute_loss(batch)
|
||||
|
||||
def save(self, fp):
    """Serialize this module's `state_dict()` to `fp` with `torch.save`.

    Args:
        fp: Destination handed straight to `torch.save` (a path or a
            writable binary file-like object).
    """
    torch.save(self.state_dict(), fp)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user