31 lines
713 B
Python
Executable File
31 lines
713 B
Python
Executable File
import dataclasses
|
|
import os
|
|
import pathlib
|
|
|
|
import pytest
|
|
|
|
# Set JAX_PLATFORMS to "cpu" so the test runs on the CPU backend.
# This assignment intentionally sits above the openpi / train imports below.
# NOTE(review): presumably those imports initialize JAX, so the variable must
# already be set when they run -- confirm before reordering imports.
os.environ["JAX_PLATFORMS"] = "cpu"
|
|
|
|
from openpi.training import config as _config
|
|
|
|
from . import train
|
|
|
|
|
|
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Run ``train.main`` twice: a fresh short run, then a resumed one.

    Args:
        tmp_path: pytest-provided temporary directory; checkpoint_base_dir is
            pointed at a "checkpoint" directory beneath it.
        config_name: Key into ``_config._CONFIGS_DICT`` selecting the base
            training config to override.
    """
    base_config = _config._CONFIGS_DICT[config_name]  # noqa: SLF001

    # Overrides that shrink the run to two steps with a batch of two and log
    # on every step.
    overrides = {
        "batch_size": 2,
        "checkpoint_base_dir": tmp_path / "checkpoint",
        "exp_name": "test",
        "overwrite": False,
        "resume": False,
        "num_train_steps": 2,
        "log_interval": 1,
    }
    fresh_run = dataclasses.replace(base_config, **overrides)
    train.main(fresh_run)

    # test resuming
    # Same settings, but with resume enabled and the step budget raised to 4.
    # NOTE(review): presumably this continues from the checkpoint written by
    # the first run -- confirm against train.main.
    resumed_run = dataclasses.replace(fresh_run, resume=True, num_train_steps=4)
    train.main(resumed_run)
|