[Port HIL-SERL] Final fixes for reward classifier (#1067)

Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Michel Aractingi
2025-05-05 11:33:09 +02:00
committed by GitHub
parent 6fa7df35df
commit 5998203a33
5 changed files with 845 additions and 168 deletions

View File

@@ -40,13 +40,13 @@ def test_binary_classifier_with_default_params():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 224, 224)),
"observation.image": torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size])
output = classifier.predict(images)
@@ -56,7 +56,7 @@ def test_binary_classifier_with_default_params():
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 512])
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@@ -79,13 +79,13 @@ def test_multiclass_classifier():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 224, 224)),
"observation.image": torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 224, 224])
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size, num_classes])
output = classifier.predict(images)
@@ -95,7 +95,7 @@ def test_multiclass_classifier():
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 512])
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"