Add training on custom openpi datasets

Cleanup instructions

clean up doc

pass linter

updates

Add test
This commit is contained in:
Michael Equi
2024-12-22 19:19:54 +00:00
parent 385780ecc3
commit 9da84a2f7f
9 changed files with 75 additions and 46 deletions

@@ -63,6 +63,8 @@ uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.
The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.
An example for how to train on your own Aloha dataset is provided in the [ALOHA Real README](examples/aloha_real/README.md).
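The memory setting described above is just an environment variable set before launching training; a minimal sketch, using the `pi0_aloha_sim` command shown earlier in this section:

```shell
# Let JAX use up to 90% of GPU memory (its default pre-allocation is 75%),
# which allows larger batch sizes while remaining stable.
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9

# Then launch training as before; all GPUs on the node are used automatically:
# uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
```

The variable can also be set inline on the same line as the `uv run` command if you prefer not to export it for the whole shell session.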
## Running examples