From 64425d5e0028fe6b3320d4186cfbd69cebe920f7 Mon Sep 17 00:00:00 2001 From: Seungjae Lee <30570922+jayLEE0301@users.noreply.github.com> Date: Thu, 11 Jul 2024 01:56:11 +0900 Subject: [PATCH] Bug fix: fix error when setting select_target_actions_indices in vqbet (#310) Co-authored-by: Alexander Soare --- lerobot/common/policies/vqbet/modeling_vqbet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 6fb9c5d..058c177 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -298,7 +298,8 @@ class VQBeTModel(nn.Module): # bin prediction head / offset prediction head part of VQ-BeT self.action_head = VQBeTHead(config) - num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1 + # Action tokens for: each observation step, the current action token, and all future action tokens. + num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1 self.register_buffer( "select_target_actions_indices", torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),