[Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578)

This commit is contained in:
Eugene Mironov
2024-12-23 16:43:55 +07:00
committed by AdilZouitine
parent 66268fcf85
commit 6340d9d17c
6 changed files with 346 additions and 1 deletions

View File

@@ -22,6 +22,11 @@ class ClassifierOutput:
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})")
class Classifier(
nn.Module,
@@ -69,6 +74,8 @@ class Classifier(
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
@@ -93,6 +100,7 @@ class Classifier(
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
)
self.classifier_head = self.classifier_head.to(self.config.device)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""