diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 6fb9c5d..058c177 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -298,7 +298,8 @@ class VQBeTModel(nn.Module): # bin prediction head / offset prediction head part of VQ-BeT self.action_head = VQBeTHead(config) - num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1 + # Action tokens for: each observation step, the current action token, and all future action tokens. + num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1 self.register_buffer( "select_target_actions_indices", torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),