backup wip

Alexander Soare
2024-04-05 18:46:30 +01:00
parent ecc7dd3b17
commit 8d2463f45b
5 changed files with 105 additions and 120 deletions


@@ -3,7 +3,7 @@
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
"""
from collections import deque
import math
import time
from itertools import chain
@@ -22,6 +22,67 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.utils import get_safe_torch_device
# class AbstractPolicy(nn.Module):
# """Base policy which all policies should be derived from.
# The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
# documentation for more information.
# Note:
# When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
# 1. set the required class attributes:
# - for classes inheriting from `AbstractDataset`: `available_datasets`
# - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
# - for classes inheriting from `AbstractPolicy`: `name`
# 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
# 3. update variables in `tests/test_available.py` by importing your new class
# """
# name: str | None = None # same name should be used to instantiate the policy in factory.py
# def __init__(self, n_action_steps: int | None):
# """
# n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
# action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method then
# adds that dimension.
# """
# super().__init__()
# assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
# self.n_action_steps = n_action_steps
# self.clear_action_queue()
# def clear_action_queue(self):
# """This should be called whenever the environment is reset."""
# if self.n_action_steps is not None:
# self._action_queue = deque([], maxlen=self.n_action_steps)
# def forward(self, *args, **kwargs) -> Tensor:
# """Inference step that makes multi-step policies compatible with their single-step environments.
# WARNING: In general, this should not be overridden.
# Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
# into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
# observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
# observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
# the subclass doesn't have to.
# This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
# 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
# the action trajectory horizon and * is the action dimensions.
# 2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
# """
# if self.n_action_steps is None:
# return self.select_actions(*args, **kwargs)
# if len(self._action_queue) == 0:
# # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# # (n_action_steps, batch_size, *), hence the transpose.
# self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
# return self._action_queue.popleft()
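For reference, the queueing behaviour the commented-out docstring describes can be sketched standalone as follows. This is a minimal illustration, not part of this commit; `select_actions` is assumed to be a callable returning a `(batch_size, n_action_steps, action_dim)` tensor:

```python
from collections import deque

from torch import Tensor


class ActionQueueSketch:
    """Minimal sketch of the multi-step action queue described above (hypothetical helper)."""

    def __init__(self, select_actions, n_action_steps: int):
        self.select_actions = select_actions  # callable returning (B, H, *) actions
        # The queue holds at most one chunk of actions at a time.
        self._action_queue = deque([], maxlen=n_action_steps)

    def __call__(self, *args, **kwargs) -> Tensor:
        if len(self._action_queue) == 0:
            # (B, H, *) -> (H, B, *) so each popleft() yields one (B, *) action per step.
            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
        return self._action_queue.popleft()
```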
class ActionChunkingTransformerPolicy(nn.Module):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
@@ -168,14 +229,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
nn.init.xavier_uniform_(p)
@torch.no_grad()
def select_actions(self, batch, *_):
def select_action(self, batch, *_):
# TODO(now): Implement queueing mechanism.
self.eval()
self._preprocess_batch(batch)
# TODO(now): What's up with this 0.182?
action = self.forward(
robot_state=batch["observation.state"] * 0.182, image=batch["observation.images.top"]
robot_state=batch["observation.state"] * 0.182,
image=batch["observation.images.top"],
return_loss=False,
)
if self.cfg.temporal_agg:
@@ -226,7 +289,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
loss = self.compute_loss(batch)
loss = self.forward(batch, return_loss=True)["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -247,44 +310,38 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def compute_loss(self, batch):
loss_dict = self.forward(
robot_state=batch["observation.state"],
image=batch["observation.images.top"],
actions=batch["action"],
)
loss = loss_dict["loss"]
return loss
def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
def forward(self, batch: dict[str, Tensor], return_loss: bool = False):
# TODO(now): Maybe this shouldn't be here?
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
images = normalize(batch["observation.images.top"])
is_training = actions is not None
if is_training: # training time
actions = actions[:, : self.horizon]
if return_loss: # training time
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(
batch["observation.state"], images, batch["action"]
)
a_hat, (mu, log_sigma_x2) = self._forward(robot_state, image, actions)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = all_l1.mean()
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {}
loss_dict["l1"] = l1
loss_dict["l1"] = l1_loss
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _ = self._forward(robot_state, image) # no action, sample from prior
action, _ = self._forward(batch["observation.state"], images)
return action
def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
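For reference, the `mean_kld` term computed above is the standard closed-form KL divergence between the diagonal-Gaussian latent posterior and a standard-normal prior (App. B of https://arxiv.org/abs/1312.6114), with `log_sigma_x2` standing for log σ² and d the latent dimension:

```latex
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^{2})) \,\middle\|\, \mathcal{N}(0, I)\right)
    = -\frac{1}{2} \sum_{i=1}^{d} \left(1 + \log \sigma_i^{2} - \mu_i^{2} - \sigma_i^{2}\right)
```

The per-dimension terms are summed over the latent dimension and averaged over the batch, and the total objective is the pad-masked L1 reconstruction loss (the `~batch["action_is_pad"]` factor above) plus `self.cfg.kl_weight` times this KL term.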
@@ -321,7 +378,9 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # (B, D)
)[
0
] # (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
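The hunk ends just as the latent distribution parameters are split into `mu` and `2log(sigma)`. For context, sampling the latent with the reparameterization trick from this parameterization typically looks like the sketch below (illustrative shapes and stand-in tensors, not the exact continuation of the file):

```python
import torch

batch_size, latent_dim = 8, 32
# Stand-in for the projection output above: first half is mu, second half is 2*log(sigma).
latent_pdf_params = torch.randn(batch_size, 2 * latent_dim)

mu = latent_pdf_params[:, :latent_dim]
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # 2*log(sigma)

# Reparameterization trick: sigma = exp(0.5 * log_sigma_x2), z = mu + sigma * eps, so
# gradients flow through mu and log_sigma_x2 but not through the random draw eps.
latent_sample = mu + (log_sigma_x2 / 2).exp() * torch.randn_like(mu)
assert latent_sample.shape == (batch_size, latent_dim)
```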