backup wip

This commit is contained in:
Alexander Soare
2024-04-11 17:51:35 +01:00
parent 91ff69d64c
commit 976a197f98
26 changed files with 661 additions and 2733 deletions

View File

@@ -5,11 +5,10 @@ from collections import deque
import hydra
import torch
from torch import nn
from diffusers.optimization import get_scheduler
from torch import Tensor, nn
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
@@ -22,8 +21,6 @@ class DiffusionPolicy(nn.Module):
cfg,
cfg_device,
cfg_noise_scheduler,
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
@@ -31,53 +28,43 @@ class DiffusionPolicy(nn.Module):
n_action_steps,
n_obs_steps,
num_inference_steps=None,
obs_as_global_cond=True,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
cond_predict_scale=True,
# parameters passed to step
**kwargs,
film_scale_modulation=True,
**_,
):
super().__init__()
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
# TODO(now): In-house this.
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
)
self.diffusion = DiffusionUnetImagePolicy(
cfg,
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
obs_as_global_cond=obs_as_global_cond,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
# parameters passed to step
**kwargs,
film_scale_modulation=film_scale_modulation,
)
self.device = get_safe_torch_device(cfg_device)
self.diffusion.to(self.device)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
@@ -116,42 +103,45 @@ class DiffusionPolicy(nn.Module):
"action": deque(maxlen=self.n_action_steps),
}
@torch.no_grad()
def select_action(self, batch, step):
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
    """Run the policy on `batch` by delegating to `select_action`.

    Any extra keyword arguments are accepted and ignored.
    """
    selected = self.select_action(batch)
    return selected
# Return one action per call: actions are served from self._queues["action"], and the model is only
# queried for a new sequence of actions when that queue is empty.
# NOTE(review): this span comes from a diff rendering without +/- markers, so lines of the previous and
# of the updated implementation appear side by side (e.g. `del step` although `step` is not a parameter
# of this signature, and the duplicated `assert len(batch) == 2`) — confirm against the repository
# which of these lines are live.
@torch.no_grad
def select_action(self, batch, **_):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
# TODO(now): Handle a batch
"""
# TODO(rcadene): remove unused step
del step
# The batch must hold exactly the image and state observations, and nothing else.
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
assert len(batch) == 2 # TODO(now): Does this not have a batch dim?
# presumably pushes the newest observations into the rolling queues — verify against populate_queues
self._queues = populate_queues(self._queues, batch)
# Only run the model when no previously generated actions are left to hand out.
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
obs_dict = {
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
self._queues["action"].extend(out["action"].transpose(0, 1))
actions = self._generate_actions(batch)
# assumes actions is (batch, step, ...) so transpose(0, 1) queues one entry per step — TODO confirm
self._queues["action"].extend(actions.transpose(0, 1))
# Hand out the oldest queued action.
action = self._queues["action"].popleft()
return action
def forward(self, batch, step):
def _generate_actions(self, batch):
    """Predict actions for `batch` with the EMA model when available, else the regular model.

    The EMA model is used only when the policy is not in training mode and
    `self.ema_diffusion` is set; in every other case `self.diffusion` is used.
    """
    if self.training or self.ema_diffusion is None:
        predictor = self.diffusion
    else:
        predictor = self.ema_diffusion
    return predictor.predict_action(batch)
def update(self, batch, **_):
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.diffusion.train()
loss = self.diffusion.compute_loss(batch)
loss = self.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -174,13 +164,11 @@ class DiffusionPolicy(nn.Module):
"update_s": time.time() - start_time,
}
# TODO(rcadene): remove hardcoding
# in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
if step % 168 == 0:
self.global_step += 1
return info
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
    """Return the loss for `batch` as computed by the wrapped `self.diffusion` model."""
    diffusion_model = self.diffusion
    return diffusion_model.compute_loss(batch)
def save(self, fp):
    """Write this object's `state_dict()` to `fp` using `torch.save`."""
    weights = self.state_dict()
    torch.save(weights, fp)