ready for review
This commit is contained in:
@@ -6,7 +6,7 @@ from collections import deque
|
||||
import hydra
|
||||
import torch
|
||||
from diffusers.optimization import get_scheduler
|
||||
from torch import Tensor, nn
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.utils import populate_queues
|
||||
@@ -43,7 +43,6 @@ class DiffusionPolicy(nn.Module):
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
# TODO(now): In-house this.
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
@@ -103,45 +102,35 @@ class DiffusionPolicy(nn.Module):
|
||||
"action": deque(maxlen=self.n_action_steps),
|
||||
}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""A forward pass through the DNN part of this policy with optional loss computation."""
|
||||
return self.select_action(batch)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch, **_):
|
||||
"""
|
||||
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
||||
# TODO(now): Handle a batch
|
||||
"""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2 # TODO(now): Does this not have a batch dim?
|
||||
assert len(batch) == 2
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
actions = self._generate_actions(batch)
|
||||
if not self.training and self.ema_diffusion is not None:
|
||||
actions = self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def _generate_actions(self, batch):
|
||||
if not self.training and self.ema_diffusion is not None:
|
||||
return self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
return self.diffusion.generate_actions(batch)
|
||||
|
||||
def update(self, batch, **_):
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
def forward(self, batch, **_):
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
@@ -166,9 +155,6 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
return info
|
||||
|
||||
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
return self.diffusion.compute_loss(batch)
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user