Initial commit

This commit is contained in:
Ury Zhilinsky
2024-12-23 13:38:06 -08:00
commit 385780ecc3
121 changed files with 15572 additions and 0 deletions

27
scripts/train_test.py Normal file
View File

@@ -0,0 +1,27 @@
import dataclasses
import pathlib
import pytest
from openpi.training import config as _config
from . import train
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
config = dataclasses.replace(
_config._CONFIGS_DICT[config_name], # noqa: SLF001
batch_size=2,
checkpoint_base_dir=tmp_path / "checkpoint",
exp_name="test",
overwrite=False,
resume=False,
num_train_steps=2,
log_interval=1,
)
train.main(config)
# test resuming
config = dataclasses.replace(config, resume=True, num_train_steps=4)
train.main(config)