backup wip

This commit is contained in:
Alexander Soare
2024-04-05 17:38:29 +01:00
parent 9c28ac8aa4
commit 1e71196fe3
7 changed files with 306 additions and 298 deletions

View File

@@ -158,7 +158,7 @@ class AlohaDataset(torch.utils.data.Dataset):
self.data_ids_per_episode = {}
ep_dicts = []
logging.info("Initialize and feed offline buffer")
frame_idx = 0
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
@@ -190,8 +190,14 @@ class AlohaDataset(torch.utils.data.Dataset):
ep_dict[f"observation.images.{cam}"] = image[:-1]
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
assert isinstance(ep_id, int)
self.data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
assert len(self.data_ids_per_episode[ep_id]) == num_frames
ep_dicts.append(ep_dict)
frame_idx += num_frames
self.data_dict = {}
keys = ep_dicts[0].keys()

View File

@@ -59,96 +59,95 @@ def make_dataset(
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
# TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
# (Pdb) stats['observation']['state']['mean']
# tensor([-0.0071, -0.6293, 1.0351, -0.0517, -0.4642, -0.0754, 0.4751, -0.0373,
# -0.3324, 0.9034, -0.2258, -0.3127, -0.2412, 0.6866])
stats["observation", "state", "mean"] = torch.tensor(
[
-0.00740268,
-0.63187766,
1.0356655,
-0.05027218,
-0.46199223,
-0.07467502,
0.47467607,
-0.03615446,
-0.33203387,
0.9038929,
-0.22060776,
-0.31011587,
-0.23484458,
0.6842416,
]
)
# (Pdb) stats['observation']['state']['std']
# tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
# 0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
stats["observation", "state", "std"] = torch.tensor(
[
0.01219023,
0.2975381,
0.16728032,
0.04733803,
0.1486037,
0.08788499,
0.31752336,
0.1049916,
0.27933604,
0.18094037,
0.26604933,
0.30466506,
0.5298686,
0.25505227,
]
)
# (Pdb) stats['action']['mean']
# tensor([-0.0075, -0.6346, 1.0353, -0.0465, -0.4686, -0.0738, 0.3723, -0.0396,
# -0.3184, 0.8991, -0.2065, -0.3182, -0.2338, 0.5593])
stats["action"]["mean"] = torch.tensor(
[
-0.00756444,
-0.6281845,
1.0312834,
-0.04664314,
-0.47211358,
-0.074527,
0.37389806,
-0.03718753,
-0.3261143,
0.8997205,
-0.21371077,
-0.31840396,
-0.23360962,
0.551947,
]
)
# (Pdb) stats['action']['std']
# tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
# 0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
stats["action"]["std"] = torch.tensor(
[
0.01252818,
0.2957442,
0.16701928,
0.04584508,
0.14833844,
0.08763024,
0.30665937,
0.10600077,
0.27572668,
0.1805853,
0.26304692,
0.30708534,
0.5305411,
0.38381037,
]
)
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) # noqa: F821
# # TODO(now): These stats are needed to use their pretrained model for sim_transfer_cube_human.
# # (Pdb) stats['observation']['state']['mean']
# # tensor([-0.0071, -0.6293, 1.0351, -0.0517, -0.4642, -0.0754, 0.4751, -0.0373,
# # -0.3324, 0.9034, -0.2258, -0.3127, -0.2412, 0.6866])
# stats["observation", "state", "mean"] = torch.tensor(
# [
# -0.00740268,
# -0.63187766,
# 1.0356655,
# -0.05027218,
# -0.46199223,
# -0.07467502,
# 0.47467607,
# -0.03615446,
# -0.33203387,
# 0.9038929,
# -0.22060776,
# -0.31011587,
# -0.23484458,
# 0.6842416,
# ]
# )
# # (Pdb) stats['observation']['state']['std']
# # tensor([0.0022, 0.0520, 0.0291, 0.0092, 0.0267, 0.0145, 0.0563, 0.0179, 0.0494,
# # 0.0326, 0.0476, 0.0535, 0.0956, 0.0513])
# stats["observation", "state", "std"] = torch.tensor(
# [
# 0.01219023,
# 0.2975381,
# 0.16728032,
# 0.04733803,
# 0.1486037,
# 0.08788499,
# 0.31752336,
# 0.1049916,
# 0.27933604,
# 0.18094037,
# 0.26604933,
# 0.30466506,
# 0.5298686,
# 0.25505227,
# ]
# )
# # (Pdb) stats['action']['mean']
# # tensor([-0.0075, -0.6346, 1.0353, -0.0465, -0.4686, -0.0738, 0.3723, -0.0396,
# # -0.3184, 0.8991, -0.2065, -0.3182, -0.2338, 0.5593])
# stats["action"]["mean"] = torch.tensor(
# [
# -0.00756444,
# -0.6281845,
# 1.0312834,
# -0.04664314,
# -0.47211358,
# -0.074527,
# 0.37389806,
# -0.03718753,
# -0.3261143,
# 0.8997205,
# -0.21371077,
# -0.31840396,
# -0.23360962,
# 0.551947,
# ]
# )
# # (Pdb) stats['action']['std']
# # tensor([0.0023, 0.0514, 0.0290, 0.0086, 0.0263, 0.0143, 0.0593, 0.0185, 0.0510,
# # 0.0328, 0.0478, 0.0531, 0.0945, 0.0794])
# stats["action"]["std"] = torch.tensor(
# [
# 0.01252818,
# 0.2957442,
# 0.16701928,
# 0.04584508,
# 0.14833844,
# 0.08763024,
# 0.30665937,
# 0.10600077,
# 0.27572668,
# 0.1805853,
# 0.26304692,
# 0.30708534,
# 0.5305411,
# 0.38381037,
# ]
# )
# transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) # noqa: F821
transforms = v2.Compose(
[
@@ -173,7 +172,11 @@ def make_dataset(
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
}
else:
delta_timestamps = None
delta_timestamps = {
"observation.images.top": [0],
"observation.state": [0],
"action": [i / clsfunc.fps for i in range(cfg.policy.horizon)],
}
dataset = clsfunc(
dataset_id=cfg.dataset_id,

View File

@@ -19,11 +19,10 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.utils import get_safe_torch_device
class ActionChunkingTransformerPolicy(AbstractPolicy):
class ActionChunkingTransformerPolicy(nn.Module):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -61,205 +60,20 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
"""
name = "act"
_multiple_obs_steps_not_handled_msg = (
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
def __init__(self, cfg, device, n_action_steps=1):
"""
TODO(alexander-soare): Add documentation for all parameters.
"""
super().__init__(n_action_steps)
super().__init__()
if getattr(cfg, "n_obs_steps", 1) != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.model = _ActionChunkingTransformer(cfg)
self._create_optimizer()
self.to(self.device)
def _create_optimizer(self):
optimizer_params_dicts = [
{
"params": [
p
for n, p in self.model.named_parameters()
if not n.startswith("backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in self.model.named_parameters()
if n.startswith("backbone") and p.requires_grad
],
"lr": self.cfg.lr_backbone,
},
]
self.optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
)
def update(self, replay_buffer, step):
del step
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
def process_batch(batch, horizon, num_slices):
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon)
image = batch["observation", "image", "top"]
image = image[:, 0] # first observation t=0
# batch, num_cam, channel, height, width
image = image.unsqueeze(1)
assert image.ndim == 5
image = image.float()
state = batch["observation", "state"]
state = state[:, 0] # first observation t=0
# batch, qpos_dim
assert state.ndim == 2
action = batch["action"]
# batch, seq, action_dim
assert action.ndim == 3
assert action.shape[1] == horizon
if self.cfg.n_obs_steps > 1:
raise NotImplementedError()
# # keep first n observations of the slice corresponding to t=[-1,0]
# image = image[:, : self.cfg.n_obs_steps]
# state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": action.to(self.device, non_blocking=True),
}
return out
start_time = time.time()
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
loss = self.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
}
return info
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
def compute_loss(self, batch):
loss_dict = self._forward(
qpos=batch["obs"]["agent_pos"],
image=batch["obs"]["image"],
actions=batch["action"],
)
loss = loss_dict["loss"]
return loss
@torch.no_grad()
def select_actions(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
self.eval()
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
obs_dict = {
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# take first predicted action or n first actions
action = action[: self.n_action_steps]
return action
def _forward(self, qpos, image, actions=None):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
is_training = actions is not None
if is_training: # training time
actions = actions[:, : self.model.horizon]
a_hat, (mu, log_sigma_x2) = self.model(qpos, image, actions)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = all_l1.mean()
loss_dict = {}
loss_dict["l1"] = l1
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _ = self.model(qpos, image) # no action, sample from prior
return action
# TODO(alexander-soare) move all this code into the policy when we have the policy API established.
class _ActionChunkingTransformer(nn.Module):
def __init__(self, cfg):
super().__init__()
self.camera_names = cfg.camera_names
self.use_vae = cfg.use_vae
self.horizon = cfg.horizon
@@ -326,26 +140,179 @@ class _ActionChunkingTransformer(nn.Module):
self._reset_parameters()
self._create_optimizer()
self.to(self.device)
def _create_optimizer(self):
optimizer_params_dicts = [
{
"params": [
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
]
},
{
"params": [
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
],
"lr": self.cfg.lr_backbone,
},
]
self.optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
)
def _reset_parameters(self):
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
@torch.no_grad()
def select_actions(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
self.eval()
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
obs_dict = {
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# take first predicted action or n first actions
action = action[: self.n_action_steps]
return action
def __call__(self, *args, **kwargs):
# TODO(now): Temporary bridge.
return self.update(*args, **kwargs)
def _preprocess_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""
Expects batch to have (at least):
{
"observation.state": (B, 1, J) tensor of robot states (joint configuration)
"observation.images.top": (B, 1, C, H, W) tensor of images.
"action": (B, H, J) tensor of actions (positional target for robot joint configuration)
"action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
}
"""
if batch["observation.state"].shape[1] != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
batch["observation.state"] = batch["observation.state"].squeeze(1)
# TODO(alexander-soare): generalize this to multiple images. Note: no squeeze is required for
# "observation.images.top" because then we'd have to unsqueeze to get get the image index dimension.
def update(self, batch, *_):
start_time = time.time()
self._preprocess_batch(batch)
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
loss = self.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.cfg.lr,
"update_s": time.time() - start_time,
}
return info
def compute_loss(self, batch):
loss_dict = self.forward(
robot_state=batch["observation.state"],
image=batch["observation.images.top"],
actions=batch["action"],
)
loss = loss_dict["loss"]
return loss
def forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
# TODO(now): Maybe this shouldn't be here?
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
is_training = actions is not None
if is_training: # training time
actions = actions[:, : self.horizon]
a_hat, (mu, log_sigma_x2) = self._forward(robot_state, image, actions)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = all_l1.mean()
loss_dict = {}
loss_dict["l1"] = l1
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (-0.5 * (1 + log_sigma_x2 - mu.pow(2) - (log_sigma_x2).exp())).sum(-1).mean()
loss_dict["kl"] = mean_kld
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.cfg.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _ = self._forward(robot_state, image) # no action, sample from prior
return action
def _forward(self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None):
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
image: (B, N, C, H, W) batch of N camera frames.
actions: (B, S, A) batch of actions from the target dataset which must be provided if the
VAE is enabled and the model is in training mode.
Returns:
(B, S, A) batch of action sequences
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension.
"""
if self.use_vae and self.training:
assert (
actions is not None
), "actions must be provided when using the variational objective in training mode."
batch_size, _ = robot_state.shape
batch_size = robot_state.shape[0]
# Prepare the latent for input to the transformer encoder.
if self.use_vae and actions is not None:
@@ -428,6 +395,13 @@ class _ActionChunkingTransformer(nn.Module):
return actions, [mu, log_sigma_x2]
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
class _TransformerEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""

View File

@@ -152,7 +152,6 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
data_s = time.time() - start_time
loss = self.diffusion.compute_loss(batch)
loss.backward()