diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 87cf59f19..98adce00b 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -350,17 +350,22 @@ class VQBeTModel(nn.Module):
         # get action features (pass through GPT)
         features = self.policy(input_tokens)
-        # len(self.config.input_shapes) is the number of different observation modes. this line gets the index of action prompt tokens.
+        # len(self.config.input_shapes) is the number of different observation modes.
+        # this line gets the index of the action prompt tokens.
         historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
             self.config.input_shapes
         )

         # only extract the output tokens at the position of action query:
-        # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
-        # Thus, it predict historical action sequence, in addition to current and future actions (predicting future actions : optional).
-        features = torch.cat(
-            [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
-        )
+        # Behavior Transformer (BeT) and VQ-BeT are both sequence-to-sequence prediction models,
+        # mapping sequential observations to sequential actions (see section 2.2 of the BeT paper, https://arxiv.org/pdf/2206.11251).
+        # Thus, they predict a historical action sequence in addition to current and future actions (predicting future actions is optional).
+        if len_additional_action_token > 0:
+            features = torch.cat(
+                [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
+            )
+        else:
+            features = features[:, historical_act_pred_index]

         # pass through action head
         action_head_output = self.action_head(features)

         # if rollout, VQ-BeT don't calculate loss
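For context, a standalone sketch of the indexing logic this patch guards. All dimensions here are placeholder values, and `n_obs_modes` stands in for `len(self.config.input_shapes)` from the patched file. The sketch shows why the zero case needs its own branch: with `len_additional_action_token == 0`, the slice `features[:, -len_additional_action_token:]` becomes `features[:, -0:]`, which selects the entire sequence rather than an empty slice.

```python
import numpy as np
import torch

# Placeholder dimensions, for illustration only.
n_obs_steps = 2                  # observation timesteps in the context window
n_obs_modes = 3                  # stands in for len(self.config.input_shapes)
len_additional_action_token = 0  # 0 when no trailing future-action query tokens exist

# Token layout per timestep: n_obs_modes observation tokens followed by one
# action query token, so the action token of step t sits at
# t * (n_obs_modes + 1) + n_obs_modes.
historical_act_pred_index = np.arange(0, n_obs_steps) * (n_obs_modes + 1) + n_obs_modes

# Dummy GPT output of shape (batch, sequence, hidden); the sequence covers the
# interleaved observation/action tokens plus any trailing action query tokens.
seq_len = n_obs_steps * (n_obs_modes + 1) + len_additional_action_token
features = torch.randn(4, seq_len, 16)

# Guarded slicing, mirroring the patch: features[:, -0:] would return the
# whole sequence instead of an empty slice, hence the explicit branch.
if len_additional_action_token > 0:
    features = torch.cat(
        [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
    )
else:
    features = features[:, historical_act_pred_index]

print(features.shape)  # torch.Size([4, 2, 16]) with the values above
```

Without the guard, the zero-token case would concatenate the historical action features with the full sequence, silently inflating the tensor handed to the action head; the `else` branch keeps exactly one feature per observation step.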