[Port HIL_SERL] Final fixes for the Reward Classifier (#598)
This commit is contained in:
committed by
Michel Aractingi
parent
e5801f467f
commit
d1d6ffd23c
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
ClassifierConfig,
|
||||
ClassifierOutput,
|
||||
)
|
||||
@@ -21,6 +20,8 @@ def test_classifier_output():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
config = ClassifierConfig()
|
||||
classifier = Classifier(config)
|
||||
|
||||
@@ -40,6 +41,8 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = ClassifierConfig(num_classes=num_classes)
|
||||
classifier = Classifier(config)
|
||||
@@ -60,6 +63,8 @@ def test_multiclass_classifier():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
config = ClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
@@ -70,6 +75,8 @@ def test_default_device():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
config = ClassifierConfig(device="meta")
|
||||
assert config.device == "meta"
|
||||
|
||||
|
||||
@@ -151,9 +151,9 @@ def test_validate():
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
|
||||
def test_resume_function(
|
||||
mock_make_policy,
|
||||
mock_get_model,
|
||||
mock_dataset,
|
||||
mock_logger,
|
||||
mock_get_last_pretrained_model_dir,
|
||||
@@ -168,7 +168,7 @@ def test_resume_function(
|
||||
|
||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||
cfg = compose(
|
||||
config_name="reward_classifier",
|
||||
config_name="hilserl_classifier",
|
||||
overrides=[
|
||||
"device=cpu",
|
||||
"seed=42",
|
||||
@@ -211,7 +211,7 @@ def test_resume_function(
|
||||
|
||||
# Instantiate the model and set make_policy to return it
|
||||
model = make_dummy_model()
|
||||
mock_make_policy.return_value = model
|
||||
mock_get_model.return_value = model
|
||||
|
||||
# Call train
|
||||
train(cfg)
|
||||
|
||||
Reference in New Issue
Block a user