tests are passing for aloha/act policies, removes AbstractPolicy

Cadene
2024-04-09 03:28:56 +00:00
parent 73dfa3c8e3
commit 6902e01db0
6 changed files with 90 additions and 167 deletions
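The change drops the AbstractPolicy base class: ActionChunkingTransformerPolicy now subclasses nn.Module directly and exposes reset(), forward(batch, step) and select_action(batch, step), with observation/action history kept in deques. A minimal sketch of the interface the diff below converges on (method bodies elided; not the exact upstream code):

from collections import deque

import torch
from torch import nn


class PolicyInterfaceSketch(nn.Module):
    """Hypothetical outline of the interface this commit moves to."""

    def __init__(self, cfg, n_obs_steps: int, n_action_steps: int):
        super().__init__()
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps

    def reset(self):
        # re-create the observation/action deques; called on env.reset()
        self._queues = {"action": deque(maxlen=self.n_action_steps)}

    def forward(self, batch: dict, step: int) -> torch.Tensor:
        # one training update on a batch of tensors; returns the scalar loss
        raise NotImplementedError

    @torch.no_grad()
    def select_action(self, batch: dict, step: int) -> torch.Tensor:
        # queue-based inference: run the model only when the action queue is empty
        raise NotImplementedError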


@@ -1,13 +1,14 @@
import logging
import time
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from torch import nn
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
from lerobot.common.utils import get_safe_torch_device
from lerobot.common.policies.utils import populate_queues
def build_act_model_and_optimizer(cfg):
@@ -41,75 +42,61 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld
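For context, kl_divergence(mu, logvar) presumably implements the standard closed-form KL between the CVAE encoder's Gaussian N(mu, exp(logvar)) and a unit Gaussian, matching the three values returned above. A sketch of that textbook formula (the exact upstream reductions may differ):

import torch


def kl_divergence_sketch(mu: torch.Tensor, logvar: torch.Tensor):
    # KL( N(mu, sigma^2) || N(0, 1) ) per latent dimension:
    #   -0.5 * (1 + logvar - mu^2 - exp(logvar))
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(-1).mean(0, keepdim=True)   # sum over latent dims, mean over batch
    dimension_wise_kld = klds.mean(0)                # per-dimension mean over the batch
    mean_kld = klds.mean(-1).mean(0, keepdim=True)   # mean over dims and batch
    return total_kld, dimension_wise_kld, mean_kld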
class ActionChunkingTransformerPolicy(AbstractPolicy):
class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg, device, n_action_steps=1):
super().__init__(n_action_steps)
def __init__(self, cfg, n_obs_steps, n_action_steps):
super().__init__()
self.cfg = cfg
self.n_obs_steps = n_obs_steps
if self.n_obs_steps > 1:
raise NotImplementedError()
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
def update(self, replay_buffer, step):
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.images.top": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
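populate_queues (imported above from lerobot.common.policies.utils) is what fills these deques on every select_action call. A plausible sketch of its behavior, assuming it back-fills freshly reset queues and otherwise acts as a sliding window (the real helper may differ):

from collections import deque


def populate_queues_sketch(queues: dict, batch: dict) -> dict:
    for key in batch:
        if key not in queues:
            continue
        q = queues[key]
        if len(q) == 0:
            # first call after reset(): repeat the current value to fill the history window
            while len(q) < q.maxlen:
                q.append(batch[key])
        else:
            # afterwards the bounded deque drops the oldest entry automatically
            q.append(batch[key])
    return queues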
def forward(self, batch, step):
del step
start_time = time.time()
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
image = batch["observation.images.top"]
# batch, num_cam, channel, height, width
image = image.unsqueeze(1)
assert image.ndim == 5
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
state = batch["observation.state"]
# batch, qpos_dim
assert state.ndim == 2
def process_batch(batch, horizon, num_slices):
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon)
action = batch["action"]
# batch, seq, action_dim
assert action.ndim == 3
image = batch["observation", "image", "top"]
image = image[:, 0] # first observation t=0
# batch, num_cam, channel, height, width
image = image.unsqueeze(1)
assert image.ndim == 5
image = image.float()
state = batch["observation", "state"]
state = state[:, 0] # first observation t=0
# batch, qpos_dim
assert state.ndim == 2
action = batch["action"]
# batch, seq, action_dim
assert action.ndim == 3
assert action.shape[1] == horizon
if self.cfg.n_obs_steps > 1:
raise NotImplementedError()
# # keep first n observations of the slice corresponding to t=[-1,0]
# image = image[:, : self.cfg.n_obs_steps]
# state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
preprocessed_batch = {
"obs": {
"image": image,
"agent_pos": state,
},
"action": action,
}
data_s = time.time() - start_time
loss = self.compute_loss(batch)
loss = self.compute_loss(preprocessed_batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -150,40 +137,52 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
return loss
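The hunk elides the middle of the training step; from the visible pieces (compute_loss, cfg.kl_weight, kl_divergence, clip_grad_norm_) the ACT objective is presumably an L1 reconstruction loss on the predicted action chunk plus a KL term scaled by kl_weight, followed by gradient clipping and an optimizer step. A rough sketch under those assumptions (the model's output signature and the clip value are guesses, not taken from this diff):

import torch
import torch.nn.functional as F


def act_training_step_sketch(model, optimizer, batch, kl_weight, max_grad_norm=10.0):
    # assumed model output: predicted action chunk plus CVAE posterior stats
    a_hat, (mu, logvar) = model(
        qpos=batch["obs"]["agent_pos"], image=batch["obs"]["image"], actions=batch["action"]
    )
    l1 = F.l1_loss(batch["action"], a_hat)
    total_kld, _, _ = kl_divergence(mu, logvar)  # helper defined earlier in this file
    loss = l1 + kl_weight * total_kld[0]

    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach(), grad_norm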
@torch.no_grad()
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
def select_action(self, batch, step):
assert "observation.images.top" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
# TODO(rcadene): remove unused step_count
del step_count
del step
self.eval()
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
obs_dict = {
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
if self.n_obs_steps == 1:
# hack to remove the time dimension
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
actions = self._forward(
# TODO(rcadene): remove unsqueeze hack to add the "number of cameras" dimension
image=batch["observation.images.top"].unsqueeze(1),
qpos=batch["observation.state"],
)
# keep only the first predicted action, or the first n actions
action = action[: self.n_action_steps]
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# the ACT model predicts a chunk of actions, but we only queue
# the first `n_action_steps` of them
for i in range(self.n_action_steps):
self._queues["action"].append(actions[:, i])
action = self._queues["action"].popleft()
return action
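The net effect is that the DETR-VAE forward pass runs only once every n_action_steps environment steps; in between, select_action simply pops cached actions from the queue. A self-contained toy sketch of that pacing (shapes and the zero tensor are stand-ins for the real model output):

from collections import deque

import torch

n_action_steps = 8
action_queue = deque(maxlen=n_action_steps)


def select_action_sketch(step: int) -> torch.Tensor:
    # mirrors the control flow above: query the model only when the queue is empty
    if len(action_queue) == 0:
        chunk = torch.zeros(1, n_action_steps, 14)  # stand-in for self._forward(...)
        for i in range(n_action_steps):
            action_queue.append(chunk[:, i])
    return action_queue.popleft()


for step in range(20):
    action = select_action_sketch(step)  # the "model" runs on steps 0, 8, 16 only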
def _forward(self, qpos, image, actions=None, is_pad=None):