From 6902e01db07e2f27d862166d093c23e24654c900 Mon Sep 17 00:00:00 2001
From: Cadene
Date: Tue, 9 Apr 2024 03:28:56 +0000
Subject: [PATCH] tests are passing for aloha/act policies, removes abstract policy

---
 lerobot/common/policies/abstract.py     |  82 -------------
 lerobot/common/policies/act/policy.py   | 153 ++++++++++++------------
 lerobot/common/policies/factory.py      |   6 +-
 lerobot/common/policies/tdmpc/policy.py |   4 +-
 lerobot/configs/policy/act.yaml         |   4 +
 tests/test_policies.py                  |   8 +-
 6 files changed, 90 insertions(+), 167 deletions(-)
 delete mode 100644 lerobot/common/policies/abstract.py

diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
deleted file mode 100644
index 6dc72be..0000000
--- a/lerobot/common/policies/abstract.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from collections import deque
-
-import torch
-from torch import Tensor, nn
-
-
-class AbstractPolicy(nn.Module):
-    """Base policy which all policies should be derived from.
-
-    The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
-    documentation for more information.
-
-    Note:
-        When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
-            1. set the required class attributes:
-                - for classes inheriting from `AbstractDataset`: `available_datasets`
-                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
-                - for classes inheriting from `AbstractPolicy`: `name`
-            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-            3. update variables in `tests/test_available.py` by importing your new class
-    """
-
-    name: str | None = None  # same name should be used to instantiate the policy in factory.py
-
-    def __init__(self, n_action_steps: int | None):
-        """
-        n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
-        action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
-        adds that dimension.
-        """
-        super().__init__()
-        assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
-        self.n_action_steps = n_action_steps
-        self.clear_action_queue()
-
-    def update(self, replay_buffer, step):
-        """One step of the policy's learning algorithm."""
-        raise NotImplementedError("Abstract method")
-
-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        self.load_state_dict(d)
-
-    def select_actions(self, observation) -> Tensor:
-        """Select an action (or trajectory of actions) based on an observation during rollout.
-
-        If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
-        actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
-        """
-        raise NotImplementedError("Abstract method")
-
-    def clear_action_queue(self):
-        """This should be called whenever the environment is reset."""
-        if self.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.n_action_steps)
-
-    def forward(self, *args, **kwargs) -> Tensor:
-        """Inference step that makes multi-step policies compatible with their single-step environments.
-
-        WARNING: In general, this should not be overriden.
-
-        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
-        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
-        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
-        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
-        the subclass doesn't have to.
-
-        This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
-        1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
-           the action trajectory horizon and * is the action dimensions.
-        2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
-        """
-        if self.n_action_steps is None:
-            return self.select_actions(*args, **kwargs)
-        if len(self._action_queue) == 0:
-            # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
-            # (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
-        return self._action_queue.popleft()
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index ae4f732..4138e91 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -1,13 +1,14 @@
 import logging
 import time
+from collections import deque
 
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision.transforms as transforms
+from torch import nn
 
-from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.act.detr_vae import build
-from lerobot.common.utils import get_safe_torch_device
+from lerobot.common.policies.utils import populate_queues
 
 
 def build_act_model_and_optimizer(cfg):
@@ -41,75 +42,61 @@ def kl_divergence(mu, logvar):
     return total_kld, dimension_wise_kld, mean_kld
 
 
-class ActionChunkingTransformerPolicy(AbstractPolicy):
+class ActionChunkingTransformerPolicy(nn.Module):
     name = "act"
 
-    def __init__(self, cfg, device, n_action_steps=1):
-        super().__init__(n_action_steps)
+    def __init__(self, cfg, n_obs_steps, n_action_steps):
+        super().__init__()
         self.cfg = cfg
+        self.n_obs_steps = n_obs_steps
+        if self.n_obs_steps > 1:
+            raise NotImplementedError()
         self.n_action_steps = n_action_steps
-        self.device = get_safe_torch_device(device)
         self.model, self.optimizer = build_act_model_and_optimizer(cfg)
         self.kl_weight = self.cfg.kl_weight
         logging.info(f"KL Weight {self.kl_weight}")
-        self.to(self.device)
 
-    def update(self, replay_buffer, step):
+    def reset(self):
+        """
+        Clear observation and action queues. Should be called on `env.reset()`
+        """
+        self._queues = {
+            "observation.images.top": deque(maxlen=self.n_obs_steps),
+            "observation.state": deque(maxlen=self.n_obs_steps),
+            "action": deque(maxlen=self.n_action_steps),
+        }
+
+    def forward(self, batch, step):
         del step
 
         start_time = time.time()
 
         self.train()
 
-        num_slices = self.cfg.batch_size
-        batch_size = self.cfg.horizon * num_slices
+        image = batch["observation.images.top"]
+        # batch, num_cam, channel, height, width
+        image = image.unsqueeze(1)
+        assert image.ndim == 5
 
-        assert batch_size % self.cfg.horizon == 0
-        assert batch_size % num_slices == 0
+        state = batch["observation.state"]
+        # batch, qpos_dim
+        assert state.ndim == 2
 
-        def process_batch(batch, horizon, num_slices):
-            # trajectory t = 64, horizon h = 16
-            # (t h) ... -> t h ...
-            batch = batch.reshape(num_slices, horizon)
+        action = batch["action"]
+        # batch, seq, action_dim
+        assert action.ndim == 3
 
-            image = batch["observation", "image", "top"]
-            image = image[:, 0]  # first observation t=0
-            # batch, num_cam, channel, height, width
-            image = image.unsqueeze(1)
-            assert image.ndim == 5
-            image = image.float()
-
-            state = batch["observation", "state"]
-            state = state[:, 0]  # first observation t=0
-            # batch, qpos_dim
-            assert state.ndim == 2
-
-            action = batch["action"]
-            # batch, seq, action_dim
-            assert action.ndim == 3
-            assert action.shape[1] == horizon
-
-            if self.cfg.n_obs_steps > 1:
-                raise NotImplementedError()
-                # # keep first n observations of the slice corresponding to t=[-1,0]
-                # image = image[:, : self.cfg.n_obs_steps]
-                # state = state[:, : self.cfg.n_obs_steps]
-
-            out = {
-                "obs": {
-                    "image": image.to(self.device, non_blocking=True),
-                    "agent_pos": state.to(self.device, non_blocking=True),
-                },
-                "action": action.to(self.device, non_blocking=True),
-            }
-            return out
-
-        batch = replay_buffer.sample(batch_size)
-        batch = process_batch(batch, self.cfg.horizon, num_slices)
+        preprocessed_batch = {
+            "obs": {
+                "image": image,
+                "agent_pos": state,
+            },
+            "action": action,
+        }
 
         data_s = time.time() - start_time
 
-        loss = self.compute_loss(batch)
+        loss = self.compute_loss(preprocessed_batch)
         loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -150,40 +137,52 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
         return loss
 
     @torch.no_grad()
-    def select_actions(self, observation, step_count):
-        if observation["image"].shape[0] != 1:
-            raise NotImplementedError("Batch size > 1 not handled")
+    def select_action(self, batch, step):
+        assert "observation.images.top" in batch
+        assert "observation.state" in batch
+        assert len(batch) == 2
+
+        self._queues = populate_queues(self._queues, batch)
 
         # TODO(rcadene): remove unused step_count
-        del step_count
+        del step
 
         self.eval()
 
-        # TODO(rcadene): remove hack
-        # add 1 camera dimension
-        observation["image", "top"] = observation["image", "top"].unsqueeze(1)
+        if len(self._queues["action"]) == 0:
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
 
-        obs_dict = {
-            "image": observation["image", "top"],
-            "agent_pos": observation["state"],
-        }
-        action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
+            if self.n_obs_steps == 1:
+                # hack to remove the time dimension
+                for key in batch:
+                    assert batch[key].shape[1] == 1
+                    batch[key] = batch[key][:, 0]
 
-        if self.cfg.temporal_agg:
-            # TODO(rcadene): implement temporal aggregation
-            raise NotImplementedError()
-            # all_time_actions[[t], t:t+num_queries] = action
-            # actions_for_curr_step = all_time_actions[:, t]
-            # actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
-            # actions_for_curr_step = actions_for_curr_step[actions_populated]
-            # k = 0.01
-            # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
-            # exp_weights = exp_weights / exp_weights.sum()
-            # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
-            # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
+            actions = self._forward(
+                # TODO(rcadene): remove unsqueeze hack to add the "number of cameras" dimension
+                image=batch["observation.images.top"].unsqueeze(1),
+                qpos=batch["observation.state"],
+            )
 
-        # take first predicted action or n first actions
-        action = action[: self.n_action_steps]
+            if self.cfg.temporal_agg:
+                # TODO(rcadene): implement temporal aggregation
+                raise NotImplementedError()
+                # all_time_actions[[t], t:t+num_queries] = action
+                # actions_for_curr_step = all_time_actions[:, t]
+                # actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
+                # actions_for_curr_step = actions_for_curr_step[actions_populated]
+                # k = 0.01
+                # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
+                # exp_weights = exp_weights / exp_weights.sum()
+                # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
+                # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
+
+            # act returns a sequence of `n` actions, but we consider only
+            # the first `n_action_steps` actions subset
+            for i in range(self.n_action_steps):
+                self._queues["action"].append(actions[:, i])
+
+        action = self._queues["action"].popleft()
         return action
 
     def _forward(self, qpos, image, actions=None, is_pad=None):
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 371ab22..8636aa6 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -25,10 +25,10 @@ def make_policy(cfg):
 
         policy = ActionChunkingTransformerPolicy(
             cfg.policy,
-            cfg.device,
-            n_obs_steps=cfg.n_obs_steps,
-            n_action_steps=cfg.n_action_steps,
+            n_obs_steps=cfg.policy.n_obs_steps,
+            n_action_steps=cfg.policy.n_action_steps,
         )
+        policy.to(cfg.device)
     else:
         raise ValueError(cfg.policy.name)
 
diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py
index 2d547f2..942ee9b 100644
--- a/lerobot/common/policies/tdmpc/policy.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -150,6 +150,8 @@ class TDMPCPolicy(nn.Module):
 
         t0 = step == 0
 
+        self.eval()
+
         if len(self._queues["action"]) == 0:
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
 
@@ -171,7 +173,7 @@
                 actions.append(action)
             action = torch.stack(actions)
 
-            # self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
+            # tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
             if i in range(self.n_action_steps):
                 self._queues["action"].append(action)
 
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index 9dca436..cf5d750 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -50,8 +50,12 @@ policy:
   utd: 1
 
   n_obs_steps: ${n_obs_steps}
+  n_action_steps: ${n_action_steps}
 
   temporal_agg: false
 
   state_dim: ???
   action_dim: ???
+
+  delta_timestamps:
+    action: "[i / ${fps} for i in range(${horizon})]"
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 5d0c0d8..8ccc7c6 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
         ("xarm", "tdmpc", ["policy.mpc=true"]),
         ("pusht", "tdmpc", ["policy.mpc=false"]),
         ("pusht", "diffusion", []),
-        # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
-        #("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
-        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
-        #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
+        ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
+        ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): xarm not working with diffusion
        # ("xarm", "diffusion", []),
     ],
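
Note on the queue-based rollout introduced in lerobot/common/policies/act/policy.py above: the patch imports a `populate_queues` helper from lerobot/common/policies/utils.py but does not include its implementation. The sketch below is an assumption inferred from how the helper is called in `select_action`, not the actual lerobot code, and the tensor shapes are only illustrative. The idea it shows: observation keys present in the batch are pushed into their deques, while the "action" deque is refilled by the policy itself once every `n_action_steps` calls and popped once per call.

from collections import deque

import torch


def populate_queues(queues: dict[str, deque], batch: dict[str, torch.Tensor]) -> dict[str, deque]:
    # Hypothetical helper: push the latest observation into each matching queue.
    for key in batch:
        if key not in queues:
            continue
        if len(queues[key]) == 0:
            # Right after reset(), repeat the first observation to fill the history.
            queues[key].extend([batch[key]] * queues[key].maxlen)
        else:
            # The deque has maxlen set, so appending drops the oldest entry.
            queues[key].append(batch[key])
    return queues


# Example with n_obs_steps=1 and n_action_steps=4, mirroring reset()/select_action():
queues = {
    "observation.images.top": deque(maxlen=1),
    "observation.state": deque(maxlen=1),
    "action": deque(maxlen=4),
}
obs = {
    "observation.images.top": torch.zeros(1, 3, 480, 640),  # (batch, channel, height, width), illustrative shape
    "observation.state": torch.zeros(1, 14),                # (batch, qpos_dim), illustrative shape
}
queues = populate_queues(queues, obs)

if len(queues["action"]) == 0:
    # Stand-in for self._forward(...): ACT predicts a chunk of future actions at once.
    actions = torch.zeros(1, 4, 14)  # (batch, n_action_steps, action_dim)
    for i in range(4):
        queues["action"].append(actions[:, i])

action = queues["action"].popleft()  # (batch, action_dim): one action per env step

This mirrors the queue logic that previously lived in `AbstractPolicy.forward()`, now handled inside each policy using the new "observation.*" key naming.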