backup wip
This commit is contained in:
@@ -161,6 +161,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if self.cfg.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
return self.select_action(self, batch)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
@@ -172,23 +175,17 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if len(self._action_queue) == 0:
|
||||
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has
|
||||
# shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
|
||||
self._action_queue.extend(self._select_actions(batch).transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad
|
||||
def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def _select_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Use the action chunking transformer to generate a sequence of actions."""
|
||||
self.eval()
|
||||
self._preprocess_batch(batch, add_obs_steps_dim=True)
|
||||
|
||||
action = self.forward(batch, return_loss=False)
|
||||
|
||||
return action[: self.cfg.n_action_steps]
|
||||
|
||||
def __call__(self, *args, **kwargs) -> dict:
|
||||
# TODO(alexander-soare): Temporary bridge until we know what to do about the `update` method.
|
||||
return self.update(*args, **kwargs)
|
||||
|
||||
def _preprocess_batch(
|
||||
self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
|
||||
) -> dict[str, Tensor]:
|
||||
@@ -216,9 +213,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
|
||||
# the image index dimension.
|
||||
|
||||
def update(self, batch, **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
def compute_loss(self, batch, **_) -> float:
|
||||
self._preprocess_batch(batch)
|
||||
|
||||
self.train()
|
||||
@@ -230,6 +225,12 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
loss = self.forward(batch, return_loss=True)["loss"]
|
||||
return loss
|
||||
|
||||
def update(self, batch, **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
|
||||
@@ -251,7 +251,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = policy(batch, step=step)
|
||||
train_info = policy.update(batch, step=step)
|
||||
|
||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||
if step % cfg.log_freq == 0:
|
||||
@@ -313,7 +313,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = policy(batch, step)
|
||||
train_info = policy.update(batch, step)
|
||||
|
||||
if step % cfg.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
||||
|
||||
Reference in New Issue
Block a user