[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

This commit is contained in:
Eugene Mironov
2025-01-06 17:34:00 +07:00
committed by Michel Aractingi
parent e5801f467f
commit d1d6ffd23c
10 changed files with 7780 additions and 15 deletions

View File

@@ -1,7 +1,6 @@
import torch
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig,
ClassifierOutput,
)
@@ -21,6 +20,8 @@ def test_classifier_output():
@require_package("transformers")
def test_binary_classifier_with_default_params():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig()
classifier = Classifier(config)
@@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
def test_multiclass_classifier():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
num_classes = 5
config = ClassifierConfig(num_classes=num_classes)
classifier = Classifier(config)
@@ -60,6 +63,8 @@ def test_multiclass_classifier():
@require_package("transformers")
def test_default_device():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig()
assert config.device == "cpu"
@@ -70,6 +75,8 @@ def test_default_device():
@require_package("transformers")
def test_explicit_device_setup():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig(device="meta")
assert config.device == "meta"