Extend reward classifier for multiple camera views (#626)

This commit is contained in:
Michel Aractingi
2025-01-13 13:57:49 +01:00
parent d1d6ffd23c
commit 181727c0fe
9 changed files with 186 additions and 50 deletions

View File

@@ -13,6 +13,7 @@ class ClassifierConfig:
model_name: str = "microsoft/resnet-50"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
def save_pretrained(self, save_dir):
"""Save config to json file."""

View File

@@ -97,7 +97,7 @@ class Classifier(
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim, self.config.hidden_dim),
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
@@ -130,11 +130,11 @@ class Classifier(
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
def forward(self, x: torch.Tensor) -> ClassifierOutput:
def forward(self, xs: torch.Tensor) -> ClassifierOutput:
"""Forward pass of the classifier."""
# For training, we expect input to be a tensor directly from LeRobotDataset
encoder_output = self._get_encoder_output(x)
logits = self.classifier_head(encoder_output)
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
logits = self.classifier_head(encoder_outputs)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
@@ -142,4 +142,10 @@ class Classifier(
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def predict_reward(self, x):
if self.config.num_classes == 2:
return (self.forward(x).probabilities > 0.5).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)