|
|
|
|
@@ -3,9 +3,10 @@
|
|
|
|
|
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
|
|
|
|
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
|
|
|
|
|
"""
|
|
|
|
|
from collections import deque
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
import time
|
|
|
|
|
from collections import deque
|
|
|
|
|
from itertools import chain
|
|
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
|
|
@@ -22,67 +23,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
|
|
|
|
from lerobot.common.utils import get_safe_torch_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# class AbstractPolicy(nn.Module):
|
|
|
|
|
# """Base policy which all policies should be derived from.
|
|
|
|
|
|
|
|
|
|
# The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
|
|
|
|
# documentation for more information.
|
|
|
|
|
|
|
|
|
|
# Note:
|
|
|
|
|
# When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
|
|
|
|
# 1. set the required class attributes:
|
|
|
|
|
# - for classes inheriting from `AbstractDataset`: `available_datasets`
|
|
|
|
|
# - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
|
|
|
|
# - for classes inheriting from `AbstractPolicy`: `name`
|
|
|
|
|
# 2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
|
|
|
|
# 3. update variables in `tests/test_available.py` by importing your new class
|
|
|
|
|
# """
|
|
|
|
|
|
|
|
|
|
# name: str | None = None # same name should be used to instantiate the policy in factory.py
|
|
|
|
|
|
|
|
|
|
# def __init__(self, n_action_steps: int | None):
|
|
|
|
|
# """
|
|
|
|
|
# n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
|
|
|
|
# action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
|
|
|
|
# adds that dimension.
|
|
|
|
|
# """
|
|
|
|
|
# super().__init__()
|
|
|
|
|
# assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
|
|
|
|
|
# self.n_action_steps = n_action_steps
|
|
|
|
|
# self.clear_action_queue()
|
|
|
|
|
|
|
|
|
|
# def clear_action_queue(self):
|
|
|
|
|
# """This should be called whenever the environment is reset."""
|
|
|
|
|
# if self.n_action_steps is not None:
|
|
|
|
|
# self._action_queue = deque([], maxlen=self.n_action_steps)
|
|
|
|
|
|
|
|
|
|
# def forward(self, fn) -> Tensor:
|
|
|
|
|
# """Inference step that makes multi-step policies compatible with their single-step environments.
|
|
|
|
|
|
|
|
|
|
# WARNING: In general, this should not be overriden.
|
|
|
|
|
|
|
|
|
|
# Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
|
|
|
|
|
# into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
|
|
|
|
|
# observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
|
|
|
|
|
# observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
|
|
|
|
# the subclass doesn't have to.
|
|
|
|
|
|
|
|
|
|
# This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
|
|
|
|
|
# 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
|
|
|
|
# the action trajectory horizon and * is the action dimensions.
|
|
|
|
|
# 2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
|
|
|
|
|
# """
|
|
|
|
|
# if self.n_action_steps is None:
|
|
|
|
|
# return self.select_actions(*args, **kwargs)
|
|
|
|
|
# if len(self._action_queue) == 0:
|
|
|
|
|
# # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
|
|
|
|
|
# # (n_action_steps, batch_size, *), hence the transpose.
|
|
|
|
|
# self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
|
|
|
|
|
# return self._action_queue.popleft()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ActionChunkingTransformerPolicy(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
|
|
|
|
@@ -228,18 +168,30 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|
|
|
|
if p.dim() > 1:
|
|
|
|
|
nn.init.xavier_uniform_(p)
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def select_action(self, batch, *_):
|
|
|
|
|
# TODO(now): Implement queueing mechanism.
|
|
|
|
|
self.eval()
|
|
|
|
|
self._preprocess_batch(batch)
|
|
|
|
|
def reset(self):
|
|
|
|
|
"""This should be called whenever the environment is reset."""
|
|
|
|
|
if self.n_action_steps is not None:
|
|
|
|
|
self._action_queue = deque([], maxlen=self.n_action_steps)
|
|
|
|
|
|
|
|
|
|
# TODO(now): What's up with this 0.182?
|
|
|
|
|
action = self.forward(
|
|
|
|
|
robot_state=batch["observation.state"] * 0.182,
|
|
|
|
|
image=batch["observation.images.top"],
|
|
|
|
|
return_loss=False,
|
|
|
|
|
)
|
|
|
|
|
def select_action(self, batch: dict[str, Tensor], *_):
|
|
|
|
|
"""
|
|
|
|
|
This method wraps `select_actions` in order to return one action at a time for execution in the
|
|
|
|
|
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
|
|
|
|
queue is empty.
|
|
|
|
|
"""
|
|
|
|
|
if len(self._action_queue) == 0:
|
|
|
|
|
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
|
|
|
|
|
# (n_action_steps, batch_size, *), hence the transpose.
|
|
|
|
|
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
|
|
|
|
|
return self._action_queue.popleft()
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def select_actions(self, batch: dict[str, Tensor]):
|
|
|
|
|
"""Use the action chunking transformer to generate a sequence of actions."""
|
|
|
|
|
self.eval()
|
|
|
|
|
self._preprocess_batch(batch, add_obs_steps_dim=True)
|
|
|
|
|
|
|
|
|
|
action = self.forward(batch, return_loss=False)
|
|
|
|
|
|
|
|
|
|
if self.cfg.temporal_agg:
|
|
|
|
|
# TODO(rcadene): implement temporal aggregation
|
|
|
|
|
@@ -257,25 +209,37 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|
|
|
|
return action[: self.n_action_steps]
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
|
|
# TODO(now): Temporary bridge.
|
|
|
|
|
# TODO(now): Temporary bridge until we know what to do about the `update` method.
|
|
|
|
|
return self.update(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def _preprocess_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
|
|
|
def _preprocess_batch(
|
|
|
|
|
self, batch: dict[str, Tensor], add_obs_steps_dim: bool = False
|
|
|
|
|
) -> dict[str, Tensor]:
|
|
|
|
|
"""
|
|
|
|
|
Expects batch to have (at least):
|
|
|
|
|
This function expects `batch` to have (at least):
|
|
|
|
|
{
|
|
|
|
|
"observation.state": (B, 1, J) tensor of robot states (joint configuration)
|
|
|
|
|
|
|
|
|
|
"observation.images.top": (B, 1, C, H, W) tensor of images.
|
|
|
|
|
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
|
|
|
|
|
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
|
|
|
|
|
"action": (B, H, J) tensor of actions (positional target for robot joint configuration)
|
|
|
|
|
"action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
if add_obs_steps_dim:
|
|
|
|
|
# Add a dimension for the observations steps. Since n_obs_steps > 1 is not supported right now,
|
|
|
|
|
# this just amounts to an unsqueeze.
|
|
|
|
|
for k in batch:
|
|
|
|
|
if k.startswith("observation."):
|
|
|
|
|
batch[k] = batch[k].unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
if batch["observation.state"].shape[1] != 1:
|
|
|
|
|
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
|
|
|
|
batch["observation.state"] = batch["observation.state"].squeeze(1)
|
|
|
|
|
# TODO(alexander-soare): generalize this to multiple images. Note: no squeeze is required for
|
|
|
|
|
# "observation.images.top" because then we'd have to unsqueeze to get get the image index dimension.
|
|
|
|
|
# TODO(alexander-soare): generalize this to multiple images.
|
|
|
|
|
assert (
|
|
|
|
|
sum(k.startswith("observation.images.") and not k.endswith("is_pad") for k in batch) == 1
|
|
|
|
|
), "ACT only handles one image for now."
|
|
|
|
|
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
|
|
|
|
|
# the image index dimension.
|
|
|
|
|
|
|
|
|
|
def update(self, batch, *_):
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
@@ -378,9 +342,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|
|
|
|
# Forward pass through VAE encoder and sample the latent with the reparameterization trick.
|
|
|
|
|
cls_token_out = self.vae_encoder(
|
|
|
|
|
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
|
|
|
|
)[
|
|
|
|
|
0
|
|
|
|
|
] # (B, D)
|
|
|
|
|
)[0] # (B, D)
|
|
|
|
|
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
|
|
|
|
mu = latent_pdf_params[:, : self.latent_dim]
|
|
|
|
|
# This is 2log(sigma). Done this way to match the original implementation.
|
|
|
|
|
|