[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

This commit is contained in:
Eugene Mironov
2025-01-06 17:34:00 +07:00
committed by Michel Aractingi
parent e5801f467f
commit d1d6ffd23c
10 changed files with 7780 additions and 15 deletions

View File

@@ -151,9 +151,9 @@ def test_validate():
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
def test_resume_function(
mock_make_policy,
mock_get_model,
mock_dataset,
mock_logger,
mock_get_last_pretrained_model_dir,
@@ -168,7 +168,7 @@ def test_resume_function(
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
cfg = compose(
config_name="reward_classifier",
config_name="hilserl_classifier",
overrides=[
"device=cpu",
"seed=42",
@@ -211,7 +211,7 @@ def test_resume_function(
# Instantiate the model and set make_policy to return it
model = make_dummy_model()
mock_make_policy.return_value = model
mock_get_model.return_value = model
# Call train
train(cfg)