Add example for training on custom openpi datasets (#140)

Michael Equi
2024-12-23 16:09:03 -08:00
committed by GitHub
9 changed files with 76 additions and 45 deletions

View File

@@ -63,6 +63,8 @@ uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.
The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.
An example of how to train on your own Aloha dataset is provided in the [ALOHA Real README](examples/aloha_real/README.md).
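As an aside, a minimal Python sketch (an illustration only, not the mechanism this repository uses) of controlling the same setting is to export the variable before the JAX backend initializes:

```python
import os

# Must be set before the JAX backend is initialized (safest: before importing jax);
# otherwise the default pre-allocation fraction applies.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import jax  # noqa: E402

print(jax.devices())  # JAX may now use up to 90% of each visible GPU's memory.
```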
## Running examples

View File

@@ -54,20 +54,6 @@ While we strongly recommend fine-tuning the model to your own data to adapt it t
## Training on your own Aloha dataset
OpenPI supports training on data collected in the default Aloha HDF5 format. To do so, you must first convert the data to the Hugging Face dataset format; we include `scripts/aloha_hd5.py` to help you do this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/config.py` and replace the `repo_id` with the id assigned to your dataset during conversion.
OpenPI supports training on data collected in the default Aloha HDF5 format using the `scripts/aloha_hd5.py` conversion script. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/config.py` (see the `aloha_static_cups_open` example config) and replace the `repo_id` with the id assigned to your dataset during conversion. Before training on a new dataset, you must first compute its norm stats using `scripts/compute_norm_stats.py`.
```python
TrainConfig(
    name=<your-config-name>,
    data=LeRobotAlohaDataConfig(
        repo_id=<your-repo-id>,
        delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
    ),
),
```
Run the training script:
```bash
uv run scripts/train.py <your-config-name>
```
NOTE: When fine-tuning the pi0 base model on Aloha data, it is recommended to set `adapt_to_pi=True`. This maps the state and action spaces of the original Aloha data to those of the Aloha data used to train the base model.
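For example, the `TrainConfig` above could be extended with this flag as follows (a minimal sketch; the config name and repo id are placeholders, not identifiers from this repository):

```python
TrainConfig(
    name="my_aloha_finetune",  # hypothetical config name
    data=LeRobotAlohaDataConfig(
        repo_id="<your-repo-id>",  # the id assigned to your dataset during conversion
        delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
        # Recommended when fine-tuning the pi0 base model on Aloha data.
        adapt_to_pi=True,
    ),
),
```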

View File

@@ -16,6 +16,7 @@ from openpi.policies import libero_policy
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.shared import delta_actions
from openpi.training import config as _config
@@ -146,7 +147,7 @@ def create_default_policy(
    logging.info("Creating policy...")
    match env:
        case EnvMode.ALOHA:
            delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
            delta_action_mask = delta_actions.make_bool_mask(6, -1, 6, -1)
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),

View File

@@ -99,25 +99,3 @@ def create_trained_policy(
        ],
        sample_kwargs=sample_kwargs,
    )


def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    result = []
    for dim in dims:
        if dim > 0:
            result.extend([True] * (dim))
        else:
            result.extend([False] * (-dim))
    return tuple(result)

View File

@@ -2,11 +2,6 @@ from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config
def test_make_bool_mask():
    assert _policy_config.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert _policy_config.make_bool_mask(2, 0, 2) == (True, True, True, True)


def test_create_trained_policy():
    policy = _policy_config.create_trained_policy(
        _config.get_config("debug"),

View File

@@ -0,0 +1,20 @@
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    result = []
    for dim in dims:
        if dim > 0:
            result.extend([True] * (dim))
        else:
            result.extend([False] * (-dim))
    return tuple(result)
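As a usage sketch of the new helper (the 6-joints-plus-gripper-per-arm reading of the Aloha action space is an assumption, not stated in this diff), the mask passed for Aloha elsewhere in this commit expands as follows:

```python
from openpi.shared import delta_actions

# Positive counts contribute True entries, negative counts contribute False entries.
# Here (assumption): 6 joints + 1 gripper per arm of a 14-dim Aloha action space,
# with True presumably marking the dimensions treated as deltas.
mask = delta_actions.make_bool_mask(6, -1, 6, -1)

assert mask == (True,) * 6 + (False,) + (True,) * 6 + (False,)
assert len(mask) == 14
```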

View File

@@ -0,0 +1,6 @@
from openpi.shared import delta_actions
def test_make_bool_mask():
    assert delta_actions.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert delta_actions.make_bool_mask(2, 0, 2) == (True, True, True, True)

View File

@@ -14,6 +14,7 @@ import openpi.models.pi0 as pi0
import openpi.models.pi0_small as pi0_small
import openpi.models.tokenizer as _tokenizer
import openpi.policies.aloha_policy as aloha_policy
from openpi.shared import delta_actions
import openpi.shared.download as download
import openpi.shared.normalize as _normalize
import openpi.training.optimizer as _optimizer
@@ -45,6 +46,9 @@ class DataConfig:
    # Indicates where the cached dataset should be stored.
    dataset_root: str | None = dataclasses.field(default_factory=default_dataset_root)
    # If true, will disable syncing the dataset from the huggingface hub. Allows training on local-only datasets.
    local_files_only: bool = False


class DataConfigFactory(Protocol):
    def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
@@ -70,6 +74,8 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
    adapt_to_pi: bool = False
    # Repack transforms. Default is used if not provided.
    repack_transforms: _transforms.Group | None = None
    # If true, will disable syncing the dataset from the huggingface hub.
    local_files_only: bool = False

    def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
        norm_stats_path = metadata_dir / self.repo_id / "norm_stats.json"
@@ -115,6 +121,7 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
                    ),
                ]
            ),
            local_files_only=self.local_files_only,
        )
@@ -237,6 +244,39 @@ _CONFIGS = [
        weight_loader=weight_loaders.GoogleViTWeightLoader(),
    ),
    #
    # Example configs.
    #
    TrainConfig(
        name="aloha_static_cups_open",
        data=LeRobotAlohaDataConfig(
            repo_id="lerobot/aloha_static_cups_open",
            delta_action_mask=delta_actions.make_bool_mask(6, -1, 6, -1),
            adapt_to_pi=True,
            repack_transforms=_transforms.Group(
                inputs=[
                    _transforms.RepackTransform(
                        {
                            "images": {
                                "cam_high": "observation.images.cam_high",
                                "cam_left_wrist": "observation.images.cam_left_wrist",
                                "cam_right_wrist": "observation.images.cam_right_wrist",
                            },
                            "state": "observation.state",
                            "actions": "action",
                        }
                    )
                ]
            ),
            # Set this to true if you are using a dataset that is not on the huggingface hub.
            local_files_only=False,
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
        num_train_steps=30_000,
        batch_size=64,
        lr_schedule=_optimizer.CosineDecaySchedule(
            warmup_steps=1_000, peak_lr=2.5e-5, decay_steps=30_000, decay_lr=2.5e-6
        ),
    ),
    # Debugging configs.
    #
    TrainConfig(

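To show how the new `local_files_only` flag introduced above might be used end to end, here is a hedged sketch of a config for a dataset that exists only on local disk (the config name and repo id are hypothetical placeholders):

```python
TrainConfig(
    name="my_local_aloha_task",  # hypothetical config name
    data=LeRobotAlohaDataConfig(
        repo_id="my-org/my-local-dataset",  # placeholder id produced by your conversion
        delta_action_mask=delta_actions.make_bool_mask(6, -1, 6, -1),
        adapt_to_pi=True,
        # Read the dataset from local disk only; do not sync from the huggingface hub.
        local_files_only=True,
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
    batch_size=64,
),
```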
View File

@@ -88,11 +88,14 @@ def create_dataset(data_config: _config.DataConfig, model: _model.Model) -> Data
    if repo_id == "fake":
        return FakeDataset(model, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(
        repo_id, root=data_config.dataset_root, local_files_only=data_config.local_files_only
    )
    return lerobot_dataset.LeRobotDataset(
        data_config.repo_id,
        delta_timestamps={"action": [t / dataset_meta.fps for t in range(model.action_horizon)]},
        root=data_config.dataset_root,
        local_files_only=data_config.local_files_only,
    )
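As a worked illustration of the `delta_timestamps` argument above (the frame rate and action horizon below are assumed example values, not taken from this commit):

```python
fps = 50            # assumed dataset frame rate
action_horizon = 5  # assumed model action horizon

delta_timestamps = {"action": [t / fps for t in range(action_horizon)]}
print(delta_timestamps)  # {'action': [0.0, 0.02, 0.04, 0.06, 0.08]}
```

Each entry is a time offset in seconds relative to the sampled frame, at which the dataset fetches the corresponding action of the chunk.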