Add training on custom openpi datasets

Cleanup instructions

clean up doc

pass linter

updates

Add test
Author: Michael Equi
Date: 2024-12-22 19:19:54 +00:00
Parent: 385780ecc3
Commit: 9da84a2f7f

9 changed files with 75 additions and 46 deletions


@@ -63,6 +63,8 @@ uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.
The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.
An example of how to train on your own Aloha dataset is provided in the [ALOHA Real README](examples/aloha_real/README.md).
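For reference, the same memory fraction can also be set from Python, as long as it happens before JAX initializes its GPU backend. The snippet below is only an illustration of the mechanism, not how the training script itself configures it:

```python
import os

# Must run before JAX allocates its GPU backend, otherwise it has no effect.
# 0.9 mirrors the XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 setting described above.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.9")

import jax  # noqa: E402

print(jax.devices())  # The backend starts here and picks up the memory fraction.
```
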
## Running examples


@@ -54,20 +54,4 @@ While we strongly recommend fine-tuning the model to your own data to adapt it t
## Training on your own Aloha dataset
OpenPI supports training on data collected in the default aloha hdf5 format. To do so, you must first convert the data to the Hugging Face dataset format; we include `scripts/aloha_hd5.py` to help with this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/configs.py` and replace the repo id with the id assigned to your dataset during conversion.
```python
TrainConfig(
    name=<your-config-name>,
    data=LeRobotAlohaDataConfig(
        repo_id=<your-repo-id>,
        delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
    ),
),
```
Run the training script:
```bash
uv run scripts/train.py <your-config-name>
```
OpenPI supports training on data collected in the default aloha hdf5 format. Convert the data with the `scripts/aloha_hd5.py` conversion script, then add a new `TrainConfig` to `src/openpi/training/configs.py` (see the `aloha_static_cups_open` example config), replacing the repo id with the id assigned to your dataset during conversion. Before training on a new dataset, you must also compute the norm stats with `scripts/compute_norm_stats.py`.
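For reference, a new config entry might look like the sketch below; the name and repo id are placeholders, and the remaining fields (number of steps, batch size, learning-rate schedule, etc.) can be copied from the `aloha_static_cups_open` example:

```python
# Illustrative sketch only; replace the name and repo_id with your own values.
TrainConfig(
    name="my_aloha_task",
    data=LeRobotAlohaDataConfig(
        repo_id="<your-repo-id>",
        # Same mask pattern as the example config: 6 joints then a gripper, per arm.
        delta_action_mask=delta_actions.make_bool_mask(6, -1, 6, -1),
        # Set to True if the converted dataset only exists locally.
        local_files_only=True,
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
),
```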


@@ -16,6 +16,7 @@ from openpi.policies import libero_policy
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.shared import delta_actions
from openpi.training import config as _config
@@ -146,7 +147,7 @@ def create_default_policy(
    logging.info("Creating policy...")
    match env:
        case EnvMode.ALOHA:
            delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
            delta_action_mask = delta_actions.make_bool_mask(6, -1, 6, -1)
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),


@@ -99,25 +99,3 @@ def create_trained_policy(
        ],
        sample_kwargs=sample_kwargs,
    )


def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    result = []
    for dim in dims:
        if dim > 0:
            result.extend([True] * (dim))
        else:
            result.extend([False] * (-dim))
    return tuple(result)


@@ -2,11 +2,6 @@ from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config
def test_make_bool_mask():
    assert _policy_config.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert _policy_config.make_bool_mask(2, 0, 2) == (True, True, True, True)


def test_create_trained_policy():
    policy = _policy_config.create_trained_policy(
        _config.get_config("debug"),


@@ -0,0 +1,20 @@
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    result = []
    for dim in dims:
        if dim > 0:
            result.extend([True] * (dim))
        else:
            result.extend([False] * (-dim))
    return tuple(result)
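A short usage sketch; the `(6, -1, 6, -1)` pattern is the one the Aloha configs in this commit pass in:

```python
from openpi.shared import delta_actions

# Positive counts become True entries, negative counts become False entries:
# 6 arm joints followed by one gripper dimension, for each of the two arms.
mask = delta_actions.make_bool_mask(6, -1, 6, -1)
assert mask == (True,) * 6 + (False,) + (True,) * 6 + (False,)
```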


@@ -0,0 +1,6 @@
from openpi.shared import delta_actions


def test_make_bool_mask():
    assert delta_actions.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert delta_actions.make_bool_mask(2, 0, 2) == (True, True, True, True)


@@ -14,6 +14,7 @@ import openpi.models.pi0 as pi0
import openpi.models.pi0_small as pi0_small
import openpi.models.tokenizer as _tokenizer
import openpi.policies.aloha_policy as aloha_policy
from openpi.shared import delta_actions
import openpi.shared.download as download
import openpi.shared.normalize as _normalize
import openpi.training.optimizer as _optimizer
@@ -45,6 +46,9 @@ class DataConfig:
    # Indicates where the cached dataset should be stored.
    dataset_root: str | None = dataclasses.field(default_factory=default_dataset_root)
    # If true, will disable syncing the dataset from the huggingface hub. Allows training on local-only datasets.
    local_files_only: bool = False


class DataConfigFactory(Protocol):
    def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
@@ -70,6 +74,8 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
    adapt_to_pi: bool = False
    # Repack transforms. Default is used if not provided.
    repack_transforms: _transforms.Group | None = None
    # If true, will disable syncing the dataset from the huggingface hub.
    local_files_only: bool = False

    def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
        norm_stats_path = metadata_dir / self.repo_id / "norm_stats.json"
@@ -115,6 +121,7 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
                    ),
                ]
            ),
            local_files_only=self.local_files_only,
        )
@@ -237,6 +244,39 @@ _CONFIGS = [
        weight_loader=weight_loaders.GoogleViTWeightLoader(),
    ),
    #
    # Example configs.
    #
    TrainConfig(
        name="aloha_static_cups_open",
        data=LeRobotAlohaDataConfig(
            repo_id="lerobot/aloha_static_cups_open",
            delta_action_mask=delta_actions.make_bool_mask(6, -1, 6, -1),
            adapt_to_pi=True,
            repack_transforms=_transforms.Group(
                inputs=[
                    _transforms.RepackTransform(
                        {
                            "images": {
                                "cam_high": "observation.images.cam_high",
                                "cam_left_wrist": "observation.images.cam_left_wrist",
                                "cam_right_wrist": "observation.images.cam_right_wrist",
                            },
                            "state": "observation.state",
                            "actions": "action",
                        }
                    )
                ]
            ),
            # Set this to true if you are using a dataset that is not on the huggingface hub.
            local_files_only=False,
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
        num_train_steps=30_000,
        batch_size=64,
        lr_schedule=_optimizer.CosineDecaySchedule(
            warmup_steps=1_000, peak_lr=2.5e-5, decay_steps=30_000, decay_lr=2.5e-6
        ),
    ),
    #
    # Debugging configs.
    #
    TrainConfig(


@@ -88,11 +88,14 @@ def create_dataset(data_config: _config.DataConfig, model: _model.Model) -> Data
    if repo_id == "fake":
        return FakeDataset(model, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(
        repo_id, root=data_config.dataset_root, local_files_only=data_config.local_files_only
    )
    return lerobot_dataset.LeRobotDataset(
        data_config.repo_id,
        delta_timestamps={"action": [t / dataset_meta.fps for t in range(model.action_horizon)]},
        root=data_config.dataset_root,
        local_files_only=data_config.local_files_only,
    )
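For intuition on the `delta_timestamps` argument above: it asks LeRobot for one future action per frame across the model's action horizon. A small worked example with assumed values (in the real code, `fps` comes from the dataset metadata and the horizon from the model config):

```python
# Assumed illustrative values: a 50 Hz dataset and a 50-step action horizon.
fps = 50
action_horizon = 50

delta_timestamps = {"action": [t / fps for t in range(action_horizon)]}

# Offsets run 0.0, 0.02, 0.04, ... up to 0.98 seconds into the future.
print(delta_timestamps["action"][:3], delta_timestamps["action"][-1])
```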