backup wip
This commit is contained in:
@@ -42,9 +42,28 @@ def kl_divergence(mu, logvar):
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
"""
|
||||
Action Chunking Transformer as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware
|
||||
(https://arxiv.org/abs/2304.13705).
|
||||
"""
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
"""
|
||||
Args:
|
||||
vae: Whether to use the variational objective. TODO(now): Give more details.
|
||||
temporal_agg: Whether to do temporal aggregation. For each timestep during rollout, the action
|
||||
returned as an exponential moving average of previously generated actions for that timestep.
|
||||
n_obs_steps: Number of time steps worth of observation to use as input.
|
||||
horizon: The number of actions to generate in one forward pass.
|
||||
kl_weight: Weight for KL divergence. Defaults to None. Only applicable when using the variational
|
||||
objective.
|
||||
batch_size: Training batch size.
|
||||
grad_clip_norm: Optionally clip the gradients to have this value as the norm at most. Defaults to
|
||||
None meaning gradient clipping is not applied.
|
||||
lr: Learning rate.
|
||||
"""
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
@@ -57,8 +76,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
@@ -103,11 +120,14 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
print(data_s)
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
@@ -151,9 +171,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
if observation["image"].shape[0] != 1:
|
||||
raise NotImplementedError("Batch size > 1 not handled")
|
||||
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
@@ -167,7 +184,17 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
"image": observation["image", "top"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
||||
# qpos = obs_dict["agent_pos"]
|
||||
# img = obs_dict["image"]
|
||||
# qpos_ = torch.load('/tmp/qpos.pth')
|
||||
# img_ = torch.load('/tmp/curr_image.pth')
|
||||
# out_ = torch.load('/tmp/out.pth')
|
||||
# import cv2, numpy as np
|
||||
# cv2.imwrite("ours.png", (obs_dict["image"][0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
|
||||
# cv2.imwrite("theirs.png", (img_[0, 0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
|
||||
# out = self._forward(qpos_, img_)
|
||||
# breakpoint()
|
||||
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
@@ -197,6 +224,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
breakpoint()
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
|
||||
Reference in New Issue
Block a user