forked from tangger/lerobot
[Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578)
This commit is contained in:
committed by
AdilZouitine
parent
66268fcf85
commit
6340d9d17c
@@ -22,6 +22,11 @@ class ClassifierOutput:
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})")
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
@@ -69,6 +74,8 @@ class Classifier(
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
self.encoder = self.encoder.to(self.config.device)
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
@@ -93,6 +100,7 @@ class Classifier(
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
)
|
||||
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
|
||||
Reference in New Issue
Block a user