[Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578)

This commit is contained in:
Eugene Mironov
2024-12-23 16:43:55 +07:00
committed by GitHub
parent 7b68bfb73b
commit 70b652f791
7 changed files with 499 additions and 2 deletions

View File

@@ -13,7 +13,7 @@ class ClassifierConfig:
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "microsoft/resnet-50"
device: str = "cuda" if torch.cuda.is_available() else "mps"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
def save_pretrained(self, save_dir):

View File

@@ -22,6 +22,11 @@ class ClassifierOutput:
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})")
class Classifier(
nn.Module,
@@ -69,6 +74,8 @@ class Classifier(
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
@@ -93,6 +100,7 @@ class Classifier(
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
)
self.classifier_head = self.classifier_head.to(self.config.device)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""