diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index bc12dfa2..a73acb4f 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -289,7 +289,7 @@ class VQBeTModel(nn.Module): # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT. self.state_projector = MLP( - config.output_shapes["action"][0], hidden_channels=[self.config.gpt_input_dim] + config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim] ) self.rgb_feature_projector = MLP( self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]