[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

This commit is contained in:
Eugene Mironov
2025-01-06 17:34:00 +07:00
committed by Michel Aractingi
parent e5801f467f
commit d1d6ffd23c
10 changed files with 7780 additions and 15 deletions

View File

@@ -1,7 +1,6 @@
import torch
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
ClassifierConfig,
ClassifierOutput,
)
@@ -21,6 +20,8 @@ def test_classifier_output():
@require_package("transformers")
def test_binary_classifier_with_default_params():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig()
classifier = Classifier(config)
@@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
def test_multiclass_classifier():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
num_classes = 5
config = ClassifierConfig(num_classes=num_classes)
classifier = Classifier(config)
@@ -60,6 +63,8 @@ def test_multiclass_classifier():
@require_package("transformers")
def test_default_device():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig()
assert config.device == "cpu"
@@ -70,6 +75,8 @@ def test_default_device():
@require_package("transformers")
def test_explicit_device_setup():
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
config = ClassifierConfig(device="meta")
assert config.device == "meta"

View File

@@ -151,9 +151,9 @@ def test_validate():
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
def test_resume_function(
mock_make_policy,
mock_get_model,
mock_dataset,
mock_logger,
mock_get_last_pretrained_model_dir,
@@ -168,7 +168,7 @@ def test_resume_function(
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
cfg = compose(
config_name="reward_classifier",
config_name="hilserl_classifier",
overrides=[
"device=cpu",
"seed=42",
@@ -211,7 +211,7 @@ def test_resume_function(
# Instantiate the model and make the patched get_model return it
model = make_dummy_model()
mock_make_policy.return_value = model
mock_get_model.return_value = model
# Call train
train(cfg)