tests are passing for aloha/act policies, removes abstract policy
This commit is contained in:
@@ -1,82 +0,0 @@
|
|||||||
from collections import deque
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractPolicy(nn.Module):
|
|
||||||
"""Base policy which all policies should be derived from.
|
|
||||||
|
|
||||||
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
|
||||||
documentation for more information.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
|
||||||
1. set the required class attributes:
|
|
||||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
|
||||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
|
||||||
- for classes inheriting from `AbstractPolicy`: `name`
|
|
||||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
|
||||||
3. update variables in `tests/test_available.py` by importing your new class
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str | None = None # same name should be used to instantiate the policy in factory.py
|
|
||||||
|
|
||||||
def __init__(self, n_action_steps: int | None):
|
|
||||||
"""
|
|
||||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
|
||||||
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
|
||||||
adds that dimension.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
|
|
||||||
self.n_action_steps = n_action_steps
|
|
||||||
self.clear_action_queue()
|
|
||||||
|
|
||||||
def update(self, replay_buffer, step):
|
|
||||||
"""One step of the policy's learning algorithm."""
|
|
||||||
raise NotImplementedError("Abstract method")
|
|
||||||
|
|
||||||
def save(self, fp):
|
|
||||||
torch.save(self.state_dict(), fp)
|
|
||||||
|
|
||||||
def load(self, fp):
|
|
||||||
d = torch.load(fp)
|
|
||||||
self.load_state_dict(d)
|
|
||||||
|
|
||||||
def select_actions(self, observation) -> Tensor:
|
|
||||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
|
||||||
|
|
||||||
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
|
||||||
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Abstract method")
|
|
||||||
|
|
||||||
def clear_action_queue(self):
|
|
||||||
"""This should be called whenever the environment is reset."""
|
|
||||||
if self.n_action_steps is not None:
|
|
||||||
self._action_queue = deque([], maxlen=self.n_action_steps)
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs) -> Tensor:
|
|
||||||
"""Inference step that makes multi-step policies compatible with their single-step environments.
|
|
||||||
|
|
||||||
WARNING: In general, this should not be overriden.
|
|
||||||
|
|
||||||
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
|
|
||||||
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
|
|
||||||
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
|
|
||||||
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
|
||||||
the subclass doesn't have to.
|
|
||||||
|
|
||||||
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
|
|
||||||
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
|
||||||
the action trajectory horizon and * is the action dimensions.
|
|
||||||
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
|
|
||||||
"""
|
|
||||||
if self.n_action_steps is None:
|
|
||||||
return self.select_actions(*args, **kwargs)
|
|
||||||
if len(self._action_queue) == 0:
|
|
||||||
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
|
|
||||||
# (n_action_steps, batch_size, *), hence the transpose.
|
|
||||||
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
|
|
||||||
return self._action_queue.popleft()
|
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
|
||||||
from lerobot.common.policies.act.detr_vae import build
|
from lerobot.common.policies.act.detr_vae import build
|
||||||
from lerobot.common.utils import get_safe_torch_device
|
from lerobot.common.policies.utils import populate_queues
|
||||||
|
|
||||||
|
|
||||||
def build_act_model_and_optimizer(cfg):
|
def build_act_model_and_optimizer(cfg):
|
||||||
@@ -41,75 +42,61 @@ def kl_divergence(mu, logvar):
|
|||||||
return total_kld, dimension_wise_kld, mean_kld
|
return total_kld, dimension_wise_kld, mean_kld
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
class ActionChunkingTransformerPolicy(nn.Module):
|
||||||
name = "act"
|
name = "act"
|
||||||
|
|
||||||
def __init__(self, cfg, device, n_action_steps=1):
|
def __init__(self, cfg, n_obs_steps, n_action_steps):
|
||||||
super().__init__(n_action_steps)
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
if self.n_obs_steps > 1:
|
||||||
|
raise NotImplementedError()
|
||||||
self.n_action_steps = n_action_steps
|
self.n_action_steps = n_action_steps
|
||||||
self.device = get_safe_torch_device(device)
|
|
||||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||||
self.kl_weight = self.cfg.kl_weight
|
self.kl_weight = self.cfg.kl_weight
|
||||||
logging.info(f"KL Weight {self.kl_weight}")
|
logging.info(f"KL Weight {self.kl_weight}")
|
||||||
self.to(self.device)
|
|
||||||
|
|
||||||
def update(self, replay_buffer, step):
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Clear observation and action queues. Should be called on `env.reset()`
|
||||||
|
"""
|
||||||
|
self._queues = {
|
||||||
|
"observation.images.top": deque(maxlen=self.n_obs_steps),
|
||||||
|
"observation.state": deque(maxlen=self.n_obs_steps),
|
||||||
|
"action": deque(maxlen=self.n_action_steps),
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, batch, step):
|
||||||
del step
|
del step
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
num_slices = self.cfg.batch_size
|
image = batch["observation.images.top"]
|
||||||
batch_size = self.cfg.horizon * num_slices
|
# batch, num_cam, channel, height, width
|
||||||
|
image = image.unsqueeze(1)
|
||||||
|
assert image.ndim == 5
|
||||||
|
|
||||||
assert batch_size % self.cfg.horizon == 0
|
state = batch["observation.state"]
|
||||||
assert batch_size % num_slices == 0
|
# batch, qpos_dim
|
||||||
|
assert state.ndim == 2
|
||||||
|
|
||||||
def process_batch(batch, horizon, num_slices):
|
action = batch["action"]
|
||||||
# trajectory t = 64, horizon h = 16
|
# batch, seq, action_dim
|
||||||
# (t h) ... -> t h ...
|
assert action.ndim == 3
|
||||||
batch = batch.reshape(num_slices, horizon)
|
|
||||||
|
|
||||||
image = batch["observation", "image", "top"]
|
preprocessed_batch = {
|
||||||
image = image[:, 0] # first observation t=0
|
"obs": {
|
||||||
# batch, num_cam, channel, height, width
|
"image": image,
|
||||||
image = image.unsqueeze(1)
|
"agent_pos": state,
|
||||||
assert image.ndim == 5
|
},
|
||||||
image = image.float()
|
"action": action,
|
||||||
|
}
|
||||||
state = batch["observation", "state"]
|
|
||||||
state = state[:, 0] # first observation t=0
|
|
||||||
# batch, qpos_dim
|
|
||||||
assert state.ndim == 2
|
|
||||||
|
|
||||||
action = batch["action"]
|
|
||||||
# batch, seq, action_dim
|
|
||||||
assert action.ndim == 3
|
|
||||||
assert action.shape[1] == horizon
|
|
||||||
|
|
||||||
if self.cfg.n_obs_steps > 1:
|
|
||||||
raise NotImplementedError()
|
|
||||||
# # keep first n observations of the slice corresponding to t=[-1,0]
|
|
||||||
# image = image[:, : self.cfg.n_obs_steps]
|
|
||||||
# state = state[:, : self.cfg.n_obs_steps]
|
|
||||||
|
|
||||||
out = {
|
|
||||||
"obs": {
|
|
||||||
"image": image.to(self.device, non_blocking=True),
|
|
||||||
"agent_pos": state.to(self.device, non_blocking=True),
|
|
||||||
},
|
|
||||||
"action": action.to(self.device, non_blocking=True),
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
|
|
||||||
batch = replay_buffer.sample(batch_size)
|
|
||||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
|
||||||
|
|
||||||
data_s = time.time() - start_time
|
data_s = time.time() - start_time
|
||||||
|
|
||||||
loss = self.compute_loss(batch)
|
loss = self.compute_loss(preprocessed_batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
@@ -150,40 +137,52 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_actions(self, observation, step_count):
|
def select_action(self, batch, step):
|
||||||
if observation["image"].shape[0] != 1:
|
assert "observation.images.top" in batch
|
||||||
raise NotImplementedError("Batch size > 1 not handled")
|
assert "observation.state" in batch
|
||||||
|
assert len(batch) == 2
|
||||||
|
|
||||||
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
# TODO(rcadene): remove unused step_count
|
# TODO(rcadene): remove unused step_count
|
||||||
del step_count
|
del step
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
# TODO(rcadene): remove hack
|
if len(self._queues["action"]) == 0:
|
||||||
# add 1 camera dimension
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||||
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
|
||||||
|
|
||||||
obs_dict = {
|
if self.n_obs_steps == 1:
|
||||||
"image": observation["image", "top"],
|
# hack to remove the time dimension
|
||||||
"agent_pos": observation["state"],
|
for key in batch:
|
||||||
}
|
assert batch[key].shape[1] == 1
|
||||||
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
batch[key] = batch[key][:, 0]
|
||||||
|
|
||||||
if self.cfg.temporal_agg:
|
actions = self._forward(
|
||||||
# TODO(rcadene): implement temporal aggregation
|
# TODO(rcadene): remove unsqueeze hack to add the "number of cameras" dimension
|
||||||
raise NotImplementedError()
|
image=batch["observation.images.top"].unsqueeze(1),
|
||||||
# all_time_actions[[t], t:t+num_queries] = action
|
qpos=batch["observation.state"],
|
||||||
# actions_for_curr_step = all_time_actions[:, t]
|
)
|
||||||
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
|
||||||
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
|
||||||
# k = 0.01
|
|
||||||
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
|
||||||
# exp_weights = exp_weights / exp_weights.sum()
|
|
||||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
|
||||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
|
||||||
|
|
||||||
# take first predicted action or n first actions
|
if self.cfg.temporal_agg:
|
||||||
action = action[: self.n_action_steps]
|
# TODO(rcadene): implement temporal aggregation
|
||||||
|
raise NotImplementedError()
|
||||||
|
# all_time_actions[[t], t:t+num_queries] = action
|
||||||
|
# actions_for_curr_step = all_time_actions[:, t]
|
||||||
|
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||||
|
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||||
|
# k = 0.01
|
||||||
|
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||||
|
# exp_weights = exp_weights / exp_weights.sum()
|
||||||
|
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||||
|
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
# act returns a sequence of `n` actions, but we consider only
|
||||||
|
# the first `n_action_steps` actions subset
|
||||||
|
for i in range(self.n_action_steps):
|
||||||
|
self._queues["action"].append(actions[:, i])
|
||||||
|
|
||||||
|
action = self._queues["action"].popleft()
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def _forward(self, qpos, image, actions=None, is_pad=None):
|
def _forward(self, qpos, image, actions=None, is_pad=None):
|
||||||
|
|||||||
@@ -25,10 +25,10 @@ def make_policy(cfg):
|
|||||||
|
|
||||||
policy = ActionChunkingTransformerPolicy(
|
policy = ActionChunkingTransformerPolicy(
|
||||||
cfg.policy,
|
cfg.policy,
|
||||||
cfg.device,
|
n_obs_steps=cfg.policy.n_obs_steps,
|
||||||
n_obs_steps=cfg.n_obs_steps,
|
n_action_steps=cfg.policy.n_action_steps,
|
||||||
n_action_steps=cfg.n_action_steps,
|
|
||||||
)
|
)
|
||||||
|
policy.to(cfg.device)
|
||||||
else:
|
else:
|
||||||
raise ValueError(cfg.policy.name)
|
raise ValueError(cfg.policy.name)
|
||||||
|
|
||||||
|
|||||||
@@ -150,6 +150,8 @@ class TDMPCPolicy(nn.Module):
|
|||||||
|
|
||||||
t0 = step == 0
|
t0 = step == 0
|
||||||
|
|
||||||
|
self.eval()
|
||||||
|
|
||||||
if len(self._queues["action"]) == 0:
|
if len(self._queues["action"]) == 0:
|
||||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||||
|
|
||||||
@@ -171,7 +173,7 @@ class TDMPCPolicy(nn.Module):
|
|||||||
actions.append(action)
|
actions.append(action)
|
||||||
action = torch.stack(actions)
|
action = torch.stack(actions)
|
||||||
|
|
||||||
# self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
|
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
|
||||||
if i in range(self.n_action_steps):
|
if i in range(self.n_action_steps):
|
||||||
self._queues["action"].append(action)
|
self._queues["action"].append(action)
|
||||||
|
|
||||||
|
|||||||
@@ -50,8 +50,12 @@ policy:
|
|||||||
utd: 1
|
utd: 1
|
||||||
|
|
||||||
n_obs_steps: ${n_obs_steps}
|
n_obs_steps: ${n_obs_steps}
|
||||||
|
n_action_steps: ${n_action_steps}
|
||||||
|
|
||||||
temporal_agg: false
|
temporal_agg: false
|
||||||
|
|
||||||
state_dim: ???
|
state_dim: ???
|
||||||
action_dim: ???
|
action_dim: ???
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
action: "[i / ${fps} for i in range(${horizon})]"
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
|||||||
("xarm", "tdmpc", ["policy.mpc=true"]),
|
("xarm", "tdmpc", ["policy.mpc=true"]),
|
||||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||||
("pusht", "diffusion", []),
|
("pusht", "diffusion", []),
|
||||||
# ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
|
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
|
||||||
#("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
|
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
|
||||||
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
|
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
|
||||||
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
|
||||||
# TODO(aliberts): xarm not working with diffusion
|
# TODO(aliberts): xarm not working with diffusion
|
||||||
# ("xarm", "diffusion", []),
|
# ("xarm", "diffusion", []),
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user