refactor(config): Move device & amp args to PreTrainedConfig (#812)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Steven Palma
2025-03-06 17:59:28 +01:00
committed by GitHub
parent 10706ed753
commit 5e9473806c
19 changed files with 62 additions and 136 deletions

View File

@@ -52,7 +52,7 @@ from lerobot.common.robot_devices.control_configs import (
from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@@ -184,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
policy = make_policy(policy_cfg, ds_meta=dataset.meta)
out_dir = tmp_path / "logger"
@@ -229,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
device=DEVICE,
use_amp=False,
)
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)