backup wip

This commit is contained in:
Alexander Soare
2024-04-02 19:11:53 +01:00
parent 11cbf1bea1
commit 2b928eedd4
8 changed files with 314 additions and 37 deletions

View File

@@ -42,9 +42,28 @@ def kl_divergence(mu, logvar):
class ActionChunkingTransformerPolicy(AbstractPolicy):
"""
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
(https://arxiv.org/abs/2304.13705).
"""
name = "act"
def __init__(self, cfg, device, n_action_steps=1):
"""
Args:
vae: Whether to use the variational objective. TODO(now): Give more details.
temporal_agg: Whether to do temporal aggregation. For each timestep during rollout, the action
returned is an exponential moving average of previously generated actions for that timestep.
n_obs_steps: Number of time steps worth of observation to use as input.
horizon: The number of actions to generate in one forward pass.
kl_weight: Weight for KL divergence. Defaults to None. Only applicable when using the variational
objective.
batch_size: Training batch size.
grad_clip_norm: Optionally clip the gradients so that their norm is at most this value. Defaults to
None, meaning gradient clipping is not applied.
lr: Learning rate.
"""
super().__init__(n_action_steps)
self.cfg = cfg
self.n_action_steps = n_action_steps
@@ -57,8 +76,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
def update(self, replay_buffer, step):
del step
start_time = time.time()
self.train()
num_slices = self.cfg.batch_size
@@ -103,11 +120,14 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
"action": action.to(self.device, non_blocking=True),
}
return out
start_time = time.time()
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
print(data_s)
loss = self.compute_loss(batch)
loss.backward()
@@ -151,9 +171,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
@torch.no_grad()
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
# TODO(rcadene): remove unused step_count
del step_count
@@ -167,7 +184,17 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
# qpos = obs_dict["agent_pos"]
# img = obs_dict["image"]
# qpos_ = torch.load('/tmp/qpos.pth')
# img_ = torch.load('/tmp/curr_image.pth')
# out_ = torch.load('/tmp/out.pth')
# import cv2, numpy as np
# cv2.imwrite("ours.png", (obs_dict["image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
# cv2.imwrite("theirs.png", (img_[0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
# out = self._forward(qpos_, img_)
# breakpoint()
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
@@ -197,6 +224,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
if is_pad is not None:
is_pad = is_pad[:, : self.model.num_queries]
breakpoint()
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")