forked from tangger/lerobot
[Port HIL_SERL] Final fixes for the Reward Classifier (#598)
This commit is contained in:
committed by
Michel Aractingi
parent
e5801f467f
commit
d1d6ffd23c
@@ -151,9 +151,9 @@ def test_validate():
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
|
||||
def test_resume_function(
|
||||
mock_make_policy,
|
||||
mock_get_model,
|
||||
mock_dataset,
|
||||
mock_logger,
|
||||
mock_get_last_pretrained_model_dir,
|
||||
@@ -168,7 +168,7 @@ def test_resume_function(
|
||||
|
||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||
cfg = compose(
|
||||
config_name="reward_classifier",
|
||||
config_name="hilserl_classifier",
|
||||
overrides=[
|
||||
"device=cpu",
|
||||
"seed=42",
|
||||
@@ -211,7 +211,7 @@ def test_resume_function(
|
||||
|
||||
# Instantiate the model and set make_policy to return it
|
||||
model = make_dummy_model()
|
||||
mock_make_policy.return_value = model
|
||||
mock_get_model.return_value = model
|
||||
|
||||
# Call train
|
||||
train(cfg)
|
||||
|
||||
Reference in New Issue
Block a user