Fix input dim (#365)

Alexander Soare
2024-08-19 11:42:32 +01:00
committed by GitHub
parent fc3e545e03
commit 0f6e0f6d74


@@ -289,7 +289,7 @@ class VQBeTModel(nn.Module):
         # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
         self.state_projector = MLP(
-            config.output_shapes["action"][0], hidden_channels=[self.config.gpt_input_dim]
+            config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
         )
         self.rgb_feature_projector = MLP(
             self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
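
For context, a minimal sketch of the bug this commit fixes. The state projector was sized to the action dimension (`config.output_shapes["action"][0]`) but is applied to `observation.state` vectors, so it fails whenever the two dimensions differ. The dimensions below are hypothetical, and the repo's MLP helper is replaced with a plain nn.Linear for brevity:

```python
import torch
from torch import nn

state_dim = 14       # hypothetical config.input_shapes["observation.state"][0]
action_dim = 7       # hypothetical config.output_shapes["action"][0]
gpt_input_dim = 512  # hypothetical config.gpt_input_dim

state = torch.randn(1, state_dim)  # a batch of observation.state vectors

# Post-fix: the projector's input size matches the state vector.
state_projector = nn.Linear(state_dim, gpt_input_dim)
tokens = state_projector(state)  # ok: shape (1, 512)

# Pre-fix: sized to the action dimension, which breaks when
# action_dim != state_dim.
buggy_projector = nn.Linear(action_dim, gpt_input_dim)
try:
    buggy_projector(state)
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (1x14 and 7x512)
```

The fix only changes which config entry sizes the projector's input; the output side (`gpt_input_dim`) is unchanged, so the downstream GPT layers are unaffected.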