Initial commit

This commit is contained in:
Ury Zhilinsky
2025-02-03 21:43:26 -08:00
commit 231a1cf7ca
121 changed files with 16349 additions and 0 deletions

30
scripts/train_test.py Normal file
View File

@@ -0,0 +1,30 @@
import dataclasses
import os
import pathlib
import pytest
os.environ["JAX_PLATFORMS"] = "cpu"
from openpi.training import config as _config
from . import train
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
config = dataclasses.replace(
_config._CONFIGS_DICT[config_name], # noqa: SLF001
batch_size=2,
checkpoint_base_dir=tmp_path / "checkpoint",
exp_name="test",
overwrite=False,
resume=False,
num_train_steps=2,
log_interval=1,
)
train.main(config)
# test resuming
config = dataclasses.replace(config, resume=True, num_train_steps=4)
train.main(config)