Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
@@ -29,7 +29,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
@@ -38,9 +37,7 @@ from lerobot.common.robot_devices.control_configs import (
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
@@ -185,20 +182,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
|
||||
out_dir = tmpdir / "logger"
|
||||
|
||||
ds_cfg = DatasetConfig(repo_id, local_files_only=True)
|
||||
train_cfg = TrainPipelineConfig(
|
||||
dataset=ds_cfg,
|
||||
policy=policy_cfg,
|
||||
output_dir=out_dir,
|
||||
device=DEVICE,
|
||||
)
|
||||
logger = Logger(train_cfg)
|
||||
logger.save_checkpoint(
|
||||
train_step=0,
|
||||
identifier=0,
|
||||
policy=policy,
|
||||
)
|
||||
pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
|
||||
policy.save_pretrained(pretrained_policy_path)
|
||||
|
||||
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
|
||||
# during inference, to reach constent fps, so we test this here.
|
||||
|
||||
Reference in New Issue
Block a user