forked from tangger/lerobot
Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
84
tests/test_train_utils.py
Normal file
84
tests/test_train_utils.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from lerobot.common.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
RNG_STATE,
|
||||
SCHEDULER_STATE,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def test_get_step_identifier():
|
||||
assert get_step_identifier(5, 1000) == "000005"
|
||||
assert get_step_identifier(123, 100_000) == "000123"
|
||||
assert get_step_identifier(456789, 1_000_000) == "0456789"
|
||||
|
||||
|
||||
def test_get_step_checkpoint_dir():
|
||||
output_dir = Path("/checkpoints")
|
||||
step_dir = get_step_checkpoint_dir(output_dir, 1000, 5)
|
||||
assert step_dir == output_dir / CHECKPOINTS_DIR / "000005"
|
||||
|
||||
|
||||
def test_save_load_training_step(tmp_path):
|
||||
save_training_step(5000, tmp_path)
|
||||
assert (tmp_path / TRAINING_STEP).is_file()
|
||||
|
||||
|
||||
def test_load_training_step(tmp_path):
|
||||
step = 5000
|
||||
save_training_step(step, tmp_path)
|
||||
loaded_step = load_training_step(tmp_path)
|
||||
assert loaded_step == step
|
||||
|
||||
|
||||
def test_update_last_checkpoint(tmp_path):
|
||||
checkpoint = tmp_path / "0005"
|
||||
checkpoint.mkdir()
|
||||
update_last_checkpoint(checkpoint)
|
||||
last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK
|
||||
assert last_checkpoint.is_symlink()
|
||||
assert last_checkpoint.resolve() == checkpoint
|
||||
|
||||
|
||||
@patch("lerobot.common.utils.train_utils.save_training_state")
|
||||
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
cfg = Mock()
|
||||
save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
|
||||
policy.save_pretrained.assert_called_once()
|
||||
cfg.save_pretrained.assert_called_once()
|
||||
mock_save_training_state.assert_called_once()
|
||||
|
||||
|
||||
def test_save_training_state(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert (tmp_path / TRAINING_STATE_DIR).is_dir()
|
||||
assert (tmp_path / TRAINING_STATE_DIR / TRAINING_STEP).is_file()
|
||||
assert (tmp_path / TRAINING_STATE_DIR / RNG_STATE).is_file()
|
||||
assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_STATE).is_file()
|
||||
assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_PARAM_GROUPS).is_file()
|
||||
assert (tmp_path / TRAINING_STATE_DIR / SCHEDULER_STATE).is_file()
|
||||
|
||||
|
||||
def test_save_load_training_state(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
|
||||
assert loaded_step == 10
|
||||
assert loaded_optimizer is optimizer
|
||||
assert loaded_scheduler is scheduler
|
||||
Reference in New Issue
Block a user