Added normalization schemes and style checks

This commit is contained in:
Michel Aractingi
2024-12-29 12:51:21 +00:00
committed by AdilZouitine
parent 9dafad15e6
commit 80b86e9bc3
10 changed files with 206 additions and 150 deletions

View File

@@ -23,9 +23,11 @@ class ClassifierOutput:
self.hidden_states = hidden_states
def __repr__(self):
return (f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})")
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class Classifier(
@@ -74,7 +76,7 @@ class Classifier(
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None: