[Port HIL-SERL] Final fixes for reward classifier (#1067)
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -409,7 +409,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
             names = ft["names"]
             # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
-            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
+            if names and names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                 shape = (shape[2], shape[0], shape[1])
         elif key == "observation.environment_state":
             type = FeatureType.ENV
@@ -15,6 +15,8 @@ class RewardClassifierConfig(PreTrainedConfig):
     name: str = "reward_classifier"
     num_classes: int = 2
     hidden_dim: int = 256
+    latent_dim: int = 256
+    image_embedding_pooling_dim: int = 8
     dropout_rate: float = 0.1
     model_name: str = "helper2424/resnet10"
     device: str = "cpu"
@@ -34,6 +34,59 @@ class ClassifierOutput:
         )


+class SpatialLearnedEmbeddings(nn.Module):
+    def __init__(self, height, width, channel, num_features=8):
+        """
+        PyTorch implementation of learned spatial embeddings.
+
+        Args:
+            height: Spatial height of input features
+            width: Spatial width of input features
+            channel: Number of input channels
+            num_features: Number of output embedding dimensions
+        """
+        super().__init__()
+        self.height = height
+        self.width = width
+        self.channel = channel
+        self.num_features = num_features
+
+        self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
+
+        nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
+
+    def forward(self, features):
+        """
+        Forward pass for spatial embedding.
+
+        Args:
+            features: Encoder output whose last_hidden_state has shape [B, C, H, W], or [C, H, W] if unbatched
+        Returns:
+            Output tensor of shape [B, C*F], or [C*F] if unbatched
+        """
+        features = features.last_hidden_state
+
+        original_shape = features.shape
+        if features.dim() == 3:
+            features = features.unsqueeze(0)  # Add batch dim
+
+        features_expanded = features.unsqueeze(-1)  # [B, C, H, W, 1]
+        kernel_expanded = self.kernel.unsqueeze(0)  # [1, C, H, W, F]
+
+        # Element-wise multiplication and spatial reduction
+        output = (features_expanded * kernel_expanded).sum(dim=(2, 3))  # Sum over H, W
+
+        # Reshape to combine channel and feature dimensions
+        output = output.view(output.size(0), -1)  # [B, C*F]
+
+        # Remove batch dim if the input was unbatched
+        if len(original_shape) == 3:
+            output = output.squeeze(0)
+
+        return output
+
+
 class Classifier(PreTrainedPolicy):
     """Image classifier built on top of a pre-trained encoder."""
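Editorial note (not part of the diff): a minimal shape-flow sketch for the SpatialLearnedEmbeddings module added above. It assumes the HF ResNet convention of last_hidden_state shaped [B, C, H, W]; the SimpleNamespace stand-in and the channel count of 512 are illustrative, not taken from the commit.

    import torch
    from types import SimpleNamespace

    # 512-channel 4x4 feature map, 8 learned spatial features per channel (assumed dims)
    emb = SpatialLearnedEmbeddings(height=4, width=4, channel=512, num_features=8)
    fake_backbone_out = SimpleNamespace(last_hidden_state=torch.randn(10, 512, 4, 4))
    pooled = emb(fake_backbone_out)
    print(pooled.shape)  # torch.Size([10, 4096]) == [B, channel * num_features]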
@@ -78,6 +131,18 @@ class Classifier(PreTrainedPolicy):
         self._setup_cnn_backbone()

         self._freeze_encoder()

+        # Extract image keys from input_features
+        self.image_keys = [
+            key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
+        ]
+
+        if self.is_cnn:
+            self.encoders = nn.ModuleDict()
+            for image_key in self.image_keys:
+                encoder = self._create_single_encoder()
+                self.encoders[image_key] = encoder
+
         self._build_classifier_head()

     def _setup_cnn_backbone(self):
@@ -95,11 +160,28 @@ class Classifier(PreTrainedPolicy):
         for param in self.encoder.parameters():
             param.requires_grad = False

+    def _create_single_encoder(self):
+        encoder = nn.Sequential(
+            self.encoder,
+            SpatialLearnedEmbeddings(
+                height=4,
+                width=4,
+                channel=self.feature_dim,
+                num_features=self.config.image_embedding_pooling_dim,
+            ),
+            nn.Dropout(self.config.dropout_rate),
+            nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
+            nn.LayerNorm(self.config.latent_dim),
+            nn.Tanh(),
+        )
+
+        return encoder
+
     def _build_classifier_head(self) -> None:
         """Initialize the classifier head architecture."""
         # Get input dimension based on model type
         if self.is_cnn:
-            input_dim = self.feature_dim
+            input_dim = self.config.latent_dim
         else:  # Transformer models
             if hasattr(self.encoder.config, "hidden_size"):
                 input_dim = self.encoder.config.hidden_size
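Editorial note (not part of the diff): each per-camera encoder built by _create_single_encoder now ends in a Linear projection to config.latent_dim, which is why _build_classifier_head switches input_dim from self.feature_dim to self.config.latent_dim and why the tests below expect hidden_states of size 256 rather than 512. A sketch of the post-backbone portion of one encoder, assuming feature_dim=512 and a 4x4 backbone feature map:

    import torch
    import torch.nn as nn
    from types import SimpleNamespace

    feature_dim, pooling_dim, latent_dim = 512, 8, 256  # assumed defaults
    per_camera_head = nn.Sequential(
        SpatialLearnedEmbeddings(height=4, width=4, channel=feature_dim, num_features=pooling_dim),
        nn.Dropout(0.1),
        nn.Linear(feature_dim * pooling_dim, latent_dim),
        nn.LayerNorm(latent_dim),
        nn.Tanh(),
    )
    fake_features = SimpleNamespace(last_hidden_state=torch.randn(10, feature_dim, 4, 4))
    print(per_camera_head(fake_features).shape)  # torch.Size([10, 256]) == [B, latent_dim]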
@@ -117,26 +199,20 @@ class Classifier(PreTrainedPolicy):
             ),
         )

-    def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
+    def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
         """Extract the appropriate output from the encoder."""
         with torch.no_grad():
             if self.is_cnn:
                 # The HF ResNet applies pooling internally
-                outputs = self.encoder(x)
-                # Get pooled output directly
-                features = outputs.pooler_output
-
-                if features.dim() > 2:
-                    features = features.squeeze(-1).squeeze(-1)
-                return features
+                outputs = self.encoders[image_key](x)
+                return outputs
             else:  # Transformer models
                 outputs = self.encoder(x)
-                if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
-                    return outputs.pooler_output
                 return outputs.last_hidden_state[:, 0, :]

     def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
         """Extract image tensors and label tensors from batch."""
+        # Check for both OBS_IMAGE and OBS_IMAGES prefixes
         images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
         labels = batch["next.reward"]

@@ -144,7 +220,9 @@ class Classifier(PreTrainedPolicy):

     def predict(self, xs: list) -> ClassifierOutput:
         """Forward pass of the classifier for inference."""
-        encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
+        encoder_outputs = torch.hstack(
+            [self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
+        )
         logits = self.classifier_head(encoder_outputs)

         if self.config.num_classes == 2:
@@ -192,8 +270,14 @@ class Classifier(PreTrainedPolicy):
         return loss, output_dict

     def predict_reward(self, batch, threshold=0.5):
-        """Legacy method for compatibility."""
+        """Eval method. Returns predicted reward with the decision threshold as argument."""
+        # Check for both OBS_IMAGE and OBS_IMAGES prefixes
+        batch = self.normalize_inputs(batch)
+        batch = self.normalize_targets(batch)
+
+        # Extract images from batch dict
         images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]

         if self.config.num_classes == 2:
             probs = self.predict(images).probabilities
             logging.debug(f"Predicted reward images: {probs}")
@@ -201,13 +285,9 @@ class Classifier(PreTrainedPolicy):
         else:
             return torch.argmax(self.predict(images).probabilities, dim=1)

-    def get_optim_params(self) -> dict:
+    def get_optim_params(self):
         """Return optimizer parameters for the policy."""
-        return {
-            "params": self.parameters(),
-            "lr": getattr(self.config, "learning_rate", 1e-4),
-            "weight_decay": getattr(self.config, "weight_decay", 0.01),
-        }
+        return self.parameters()

     def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
         """
File diff suppressed because it is too large
@@ -40,13 +40,13 @@ def test_binary_classifier_with_default_params():
     batch_size = 10

     input = {
-        "observation.image": torch.rand((batch_size, 3, 224, 224)),
+        "observation.image": torch.rand((batch_size, 3, 128, 128)),
         "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
     }

     images, labels = classifier.extract_images_and_labels(input)
     assert len(images) == 1
-    assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
+    assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
     assert labels.shape == torch.Size([batch_size])

     output = classifier.predict(images)
@@ -56,7 +56,7 @@ def test_binary_classifier_with_default_params():
     assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
     assert output.probabilities.shape == torch.Size([batch_size])
     assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
-    assert output.hidden_states.shape == torch.Size([batch_size, 512])
+    assert output.hidden_states.shape == torch.Size([batch_size, 256])
     assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"

@@ -79,13 +79,13 @@ def test_multiclass_classifier():
     batch_size = 10

     input = {
-        "observation.image": torch.rand((batch_size, 3, 224, 224)),
+        "observation.image": torch.rand((batch_size, 3, 128, 128)),
         "next.reward": torch.rand((batch_size, num_classes)),
     }

     images, labels = classifier.extract_images_and_labels(input)
     assert len(images) == 1
-    assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
+    assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
     assert labels.shape == torch.Size([batch_size, num_classes])

     output = classifier.predict(images)
@@ -95,7 +95,7 @@ def test_multiclass_classifier():
     assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
     assert output.probabilities.shape == torch.Size([batch_size, num_classes])
     assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
-    assert output.hidden_states.shape == torch.Size([batch_size, 512])
+    assert output.hidden_states.shape == torch.Size([batch_size, 256])
     assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
