Initial commit
This commit is contained in:
30
scripts/train_test.py
Normal file
30
scripts/train_test.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||
|
||||
from openpi.training import config as _config
|
||||
|
||||
from . import train
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["debug"])
|
||||
def test_train(tmp_path: pathlib.Path, config_name: str):
|
||||
config = dataclasses.replace(
|
||||
_config._CONFIGS_DICT[config_name], # noqa: SLF001
|
||||
batch_size=2,
|
||||
checkpoint_base_dir=tmp_path / "checkpoint",
|
||||
exp_name="test",
|
||||
overwrite=False,
|
||||
resume=False,
|
||||
num_train_steps=2,
|
||||
log_interval=1,
|
||||
)
|
||||
train.main(config)
|
||||
|
||||
# test resuming
|
||||
config = dataclasses.replace(config, resume=True, num_train_steps=4)
|
||||
train.main(config)
|
||||
Reference in New Issue
Block a user