tests are passing for aloha/act policies, removes abstract policy

Cadene
2024-04-09 03:28:56 +00:00
parent 73dfa3c8e3
commit 6902e01db0
6 changed files with 90 additions and 167 deletions

View File

@@ -1,82 +0,0 @@
from collections import deque
import torch
from torch import Tensor, nn
class AbstractPolicy(nn.Module):
"""Base policy which all policies should be derived from.
The forward method should generally not be overridden, as it plays the role of handling multi-step policies. See its
documentation for more information.
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
name: str | None = None # same name should be used to instantiate the policy in factory.py
def __init__(self, n_action_steps: int | None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
action is returned by `select_actions` and that it doesn't have a horizon dimension. The `forward` method then
adds that dimension.
"""
super().__init__()
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
self.n_action_steps = n_action_steps
self.clear_action_queue()
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
def select_actions(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
actions. Otherwise, if n_action_steps is None, this should return a (batch_size, *) tensor of actions.
"""
raise NotImplementedError("Abstract method")
def clear_action_queue(self):
"""This should be called whenever the environment is reset."""
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def forward(self, *args, **kwargs) -> Tensor:
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overridden.
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * represents the action dimensions.
2. Prior to the `select_actions` method being called, there is an `n_action_steps` instance attribute defined.
"""
if self.n_action_steps is None:
return self.select_actions(*args, **kwargs)
if len(self._action_queue) == 0:
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
# (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
return self._action_queue.popleft()
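
To make the queue handling described in the docstring above concrete, here is a minimal, self-contained sketch of the same idea (the `DummyPolicy` class, shapes, and dimensions are illustrative assumptions, not part of this commit): `select_actions` produces a chunk of actions, `forward` transposes it into the queue, and subsequent calls pop one action per environment step until the queue runs empty.

from collections import deque

import torch


class DummyPolicy:
    """Toy stand-in mirroring the queueing logic of `AbstractPolicy.forward` (illustrative only)."""

    def __init__(self, n_action_steps=4, action_dim=2):
        self.n_action_steps = n_action_steps
        self.action_dim = action_dim
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, observation):
        # Returns a whole trajectory chunk: (batch_size, n_action_steps, action_dim).
        return torch.randn(observation.shape[0], self.n_action_steps, self.action_dim)

    def forward(self, observation):
        if len(self._action_queue) == 0:
            # The queue stores (n_action_steps, batch_size, action_dim), hence the transpose.
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


policy = DummyPolicy()
obs = torch.zeros(1, 8)  # (batch_size, obs_dim)
for _ in range(5):
    # The model is queried on the 1st and 5th calls; the others are served from the queue.
    action = policy.forward(obs)  # (batch_size, action_dim)

This is the behaviour the removed base class provided to its subclasses; with its deletion, each policy below manages its own queues through a `reset()` method and a per-key `_queues` dict.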

View File

@@ -1,13 +1,14 @@
import logging
import time
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from torch import nn
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
from lerobot.common.utils import get_safe_torch_device
from lerobot.common.policies.utils import populate_queues
def build_act_model_and_optimizer(cfg):
@@ -41,75 +42,61 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld
class ActionChunkingTransformerPolicy(AbstractPolicy):
class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg, device, n_action_steps=1):
super().__init__(n_action_steps)
def __init__(self, cfg, n_obs_steps, n_action_steps):
super().__init__()
self.cfg = cfg
self.n_obs_steps = n_obs_steps
if self.n_obs_steps > 1:
raise NotImplementedError()
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
def update(self, replay_buffer, step):
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.images.top": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
def forward(self, batch, step):
del step
start_time = time.time()
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
image = batch["observation.images.top"]
# batch, num_cam, channel, height, width
image = image.unsqueeze(1)
assert image.ndim == 5
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
state = batch["observation.state"]
# batch, qpos_dim
assert state.ndim == 2
def process_batch(batch, horizon, num_slices):
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon)
action = batch["action"]
# batch, seq, action_dim
assert action.ndim == 3
image = batch["observation", "image", "top"]
image = image[:, 0] # first observation t=0
# batch, num_cam, channel, height, width
image = image.unsqueeze(1)
assert image.ndim == 5
image = image.float()
state = batch["observation", "state"]
state = state[:, 0] # first observation t=0
# batch, qpos_dim
assert state.ndim == 2
action = batch["action"]
# batch, seq, action_dim
assert action.ndim == 3
assert action.shape[1] == horizon
if self.cfg.n_obs_steps > 1:
raise NotImplementedError()
# # keep first n observations of the slice corresponding to t=[-1,0]
# image = image[:, : self.cfg.n_obs_steps]
# state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
preprocessed_batch = {
"obs": {
"image": image,
"agent_pos": state,
},
"action": action,
}
data_s = time.time() - start_time
loss = self.compute_loss(batch)
loss = self.compute_loss(preprocessed_batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -150,40 +137,52 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
return loss
@torch.no_grad()
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
def select_action(self, batch, step):
assert "observation.images.top" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
# TODO(rcadene): remove unused step_count
del step_count
del step
self.eval()
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
obs_dict = {
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
if self.n_obs_steps == 1:
# hack to remove the time dimension
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
actions = self._forward(
# TODO(rcadene): remove unsqueeze hack to add the "number of cameras" dimension
image=batch["observation.images.top"].unsqueeze(1),
qpos=batch["observation.state"],
)
# take first predicted action or n first actions
action = action[: self.n_action_steps]
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# act returns a sequence of `n` actions, but we only keep
# the first `n_action_steps` actions of that sequence
for i in range(self.n_action_steps):
self._queues["action"].append(actions[:, i])
action = self._queues["action"].popleft()
return action
def _forward(self, qpos, image, actions=None, is_pad=None):
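
A rough sketch of the calling convention implied by the new stateful interface (the rollout loop, tensor shapes, and image resolution below are assumptions for illustration; `policy` is an `ActionChunkingTransformerPolicy` built elsewhere, e.g. by the factory): `reset()` clears the queues whenever the environment is reset, and `select_action` is then called once per environment step, re-querying the model only when the action queue is empty, i.e. every `n_action_steps` steps.

import torch

policy.reset()  # clear observation/action queues on env reset
for step in range(100):
    # Shapes are assumed: (batch, channel, height, width) image and (batch, qpos_dim) state.
    batch = {
        "observation.images.top": torch.zeros(1, 3, 480, 640),
        "observation.state": torch.zeros(1, 14),
    }
    action = policy.select_action(batch, step)  # (batch, action_dim), popped from the action queue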

View File

@@ -25,10 +25,10 @@ def make_policy(cfg):
policy = ActionChunkingTransformerPolicy(
cfg.policy,
cfg.device,
n_obs_steps=cfg.n_obs_steps,
n_action_steps=cfg.n_action_steps,
n_obs_steps=cfg.policy.n_obs_steps,
n_action_steps=cfg.policy.n_action_steps,
)
policy.to(cfg.device)
else:
raise ValueError(cfg.policy.name)

View File

@@ -150,6 +150,8 @@ class TDMPCPolicy(nn.Module):
t0 = step == 0
self.eval()
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
@@ -171,7 +173,7 @@ class TDMPCPolicy(nn.Module):
actions.append(action)
action = torch.stack(actions)
# self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` times
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` times
if i in range(self.n_action_steps):
self._queues["action"].append(action)

View File

@@ -50,8 +50,12 @@ policy:
utd: 1
n_obs_steps: ${n_obs_steps}
n_action_steps: ${n_action_steps}
temporal_agg: false
state_dim: ???
action_dim: ???
delta_timestamps:
action: "[i / ${fps} for i in range(${horizon})]"

View File

@@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
("xarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
# ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
#("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
# TODO(aliberts): xarm not working with diffusion
# ("xarm", "diffusion", []),
],