Use HF Papers (#1120)
This commit is contained in:
committed by
GitHub
parent
2de93a8000
commit
edfebd522c
@@ -162,7 +162,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
# loss: total loss of training RVQ
|
||||
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
|
||||
@@ -185,7 +185,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||
(https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation.
|
||||
|
||||
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||
@@ -387,7 +387,7 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||
# mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
|
||||
# mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://huggingface.co/papers/2206.11251).
|
||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
if len_additional_action_token > 0:
|
||||
features = torch.cat(
|
||||
@@ -824,8 +824,8 @@ class VqVae(nn.Module):
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
# this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181)
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://huggingface.co/papers/2403.03181)
|
||||
# this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://huggingface.co/papers/2403.03181)
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
with torch.no_grad():
|
||||
state_rep = self.encoder(state)
|
||||
@@ -838,7 +838,7 @@ class VqVae(nn.Module):
|
||||
return state_vq, vq_code
|
||||
|
||||
def vqvae_forward(self, state):
|
||||
# This function passes the given data through Residual VQ with Encoder and Decoder (please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181).
|
||||
# This function passes the given data through Residual VQ with Encoder and Decoder (please refer to section 3.2 in the paper https://huggingface.co/papers/2403.03181).
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
# We start with passing action (or action chunk) at:t+n through the encoder ϕ.
|
||||
state_rep = self.encoder(state)
|
||||
|
||||
@@ -336,7 +336,7 @@ class ResidualVQ(nn.Module):
|
||||
"""
|
||||
Residual VQ is composed of multiple VectorQuantize layers.
|
||||
|
||||
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
||||
Follows Algorithm 1. in https://huggingface.co/papers/2107.03312
|
||||
"Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is
|
||||
passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional
|
||||
Nq -1 vector quantizers, as described in Algorithm 1."
|
||||
@@ -1006,7 +1006,7 @@ def gumbel_sample(
|
||||
if not straight_through or temperature <= 0.0 or not training:
|
||||
return ind, one_hot
|
||||
|
||||
# use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
|
||||
# use reinmax for better second-order accuracy - https://huggingface.co/papers/2304.08612
|
||||
# algorithm 2
|
||||
|
||||
if reinmax:
|
||||
@@ -1156,7 +1156,7 @@ def batched_embedding(indices, embeds):
|
||||
|
||||
|
||||
def orthogonal_loss_fn(t):
|
||||
# eq (2) from https://arxiv.org/abs/2112.00384
|
||||
# eq (2) from https://huggingface.co/papers/2112.00384
|
||||
h, n = t.shape[:2]
|
||||
normed_codes = F.normalize(t, p=2, dim=-1)
|
||||
cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
|
||||
|
||||
Reference in New Issue
Block a user