Files
openpi/scripts/train_test.py
lzy 65d864861b
Some checks are pending
pre-commit / pre-commit (push) Waiting to run
add
2025-04-26 22:10:42 +08:00

31 lines
713 B
Python
Executable File

import dataclasses
import os
import pathlib
import pytest
# Restrict JAX to the CPU backend so the test does not need an accelerator.
# This is deliberately set before the `openpi` imports below, presumably so it
# is in place before JAX is imported/initialized by them — confirm.
os.environ["JAX_PLATFORMS"] = "cpu"
from openpi.training import config as _config
from . import train
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Smoke-test ``train.main`` on a tiny run, then run it again with ``resume`` on.

    The named config is shrunk to a 2-step run (batch size 2, logging every
    step) whose checkpoint base directory is ``tmp_path / "checkpoint"``. A
    second invocation reuses that config with ``resume=True`` and a 4-step
    budget. There are no explicit assertions: the test passes if neither call
    raises.

    Args:
        tmp_path: Per-test temporary directory provided by pytest.
        config_name: Key into ``_config._CONFIGS_DICT`` selecting the base config.
    """
    base_config = _config._CONFIGS_DICT[config_name]  # noqa: SLF001
    initial_config = dataclasses.replace(
        base_config,
        batch_size=2,
        checkpoint_base_dir=tmp_path / "checkpoint",
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(initial_config)

    # Resume the same experiment (same checkpoint dir and exp_name) with a larger
    # step budget; presumably this continues from the first run's checkpoint.
    resumed_config = dataclasses.replace(initial_config, resume=True, num_train_steps=4)
    train.main(resumed_config)