Add training on custom openpi datasets
Clean up instructions; doc pass; linter updates; add a test.
@@ -63,6 +63,8 @@ uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.
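For example, the fraction can also be set inline for a single run (a minimal sketch; the command is the one shown in the hunk header above):

```bash
# Let JAX pre-allocate up to 90% of GPU memory for this invocation only.
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
```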
The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.
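To restrict a run to a subset of the node's GPUs, one option is `CUDA_VISIBLE_DEVICES` (standard CUDA/JAX behavior, not an openpi flag; a sketch, not from this commit):

```bash
# Expose only GPUs 0 and 1 to the training script.
CUDA_VISIBLE_DEVICES=0,1 uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment
```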
An example of how to train on your own Aloha dataset is provided in the [ALOHA Real README](examples/aloha_real/README.md).
## Running examples
@@ -54,20 +54,4 @@ While we strongly recommend fine-tuning the model to your own data to adapt it t
## Training on your own Aloha dataset
OpenPI supports training on data collected in the default Aloha hdf5 format. To do so, you must first convert the data to the huggingface format. We include `scripts/aloha_hd5.py` to help you do this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/configs.py` and replace the repo id with the id assigned to your dataset during conversion.
```python
TrainConfig(
    name=<your-config-name>,
    data=LeRobotAlohaDataConfig(
        repo_id=<your-repo-id>,
        delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
    ),
),
```
Run the training script:
```bash
uv run scripts/train.py <your-config-name>
```
OpenPI supports training on data collected in the default Aloha hdf5 format using the `scripts/aloha_hd5.py` conversion script. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/configs.py` (see the `aloha_static_cups_open` example config) and replace the repo id with the id assigned to your dataset during conversion. Before training on a new dataset, you must first compute the norm stats using `scripts/compute_norm_stats.py`.
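Put together, the workflow described above has three steps. The exact arguments of the conversion and norm-stats scripts are not shown in this diff, so the invocations below are assumptions; check each script's `--help`:

```bash
# 1. Convert the raw Aloha hdf5 data to the huggingface format (flags assumed; see --help).
uv run scripts/aloha_hd5.py --help
# 2. Compute norm stats for the new dataset (argument assumed to be the config name).
uv run scripts/compute_norm_stats.py <your-config-name>
# 3. Train with the config you added.
uv run scripts/train.py <your-config-name>
```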
@@ -16,6 +16,7 @@ from openpi.policies import libero_policy
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.shared import delta_actions
from openpi.training import config as _config
@@ -146,7 +147,7 @@ def create_default_policy(
    logging.info("Creating policy...")
    match env:
        case EnvMode.ALOHA:
            delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
            delta_action_mask = delta_actions.make_bool_mask(6, -1, 6, -1)
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),
@@ -99,25 +99,3 @@ def create_trained_policy(
        ],
        sample_kwargs=sample_kwargs,
    )


def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    result = []
    for dim in dims:
        if dim > 0:
            result.extend([True] * (dim))
        else:
            result.extend([False] * (-dim))
    return tuple(result)
@@ -2,11 +2,6 @@ from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config


def test_make_bool_mask():
    assert _policy_config.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert _policy_config.make_bool_mask(2, 0, 2) == (True, True, True, True)


def test_create_trained_policy():
    policy = _policy_config.create_trained_policy(
        _config.get_config("debug"),
20 src/openpi/shared/delta_actions.py Normal file
@@ -0,0 +1,20 @@
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    result = []
    for dim in dims:
        if dim > 0:
            result.extend([True] * (dim))
        else:
            result.extend([False] * (-dim))
    return tuple(result)
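For reference, `make_bool_mask(6, -1, 6, -1)` (the mask used for the 14-dimensional Aloha action space throughout this commit) expands to the same mask the old README spelled out by hand. A minimal check, assuming the module layout added in this commit:

```python
from openpi.shared import delta_actions

# Per arm: 6 dims masked True, then 1 dim masked False
# (interpretation assumed from the field name `delta_action_mask`:
# delta for the joints, absolute for the gripper).
mask = delta_actions.make_bool_mask(6, -1, 6, -1)
assert mask == tuple([True] * 6 + [False] + [True] * 6 + [False])
```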
6 src/openpi/shared/delta_actions_test.py Normal file
@@ -0,0 +1,6 @@
from openpi.shared import delta_actions


def test_make_bool_mask():
    assert delta_actions.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert delta_actions.make_bool_mask(2, 0, 2) == (True, True, True, True)
@@ -14,6 +14,7 @@ import openpi.models.pi0 as pi0
import openpi.models.pi0_small as pi0_small
import openpi.models.tokenizer as _tokenizer
import openpi.policies.aloha_policy as aloha_policy
from openpi.shared import delta_actions
import openpi.shared.download as download
import openpi.shared.normalize as _normalize
import openpi.training.optimizer as _optimizer
@@ -45,6 +46,9 @@ class DataConfig:
    # Indicates where the cached dataset should be stored.
    dataset_root: str | None = dataclasses.field(default_factory=default_dataset_root)

    # If true, will disable syncing the dataset from the huggingface hub. Allows training on local-only datasets.
    local_files_only: bool = False


class DataConfigFactory(Protocol):
    def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
@@ -70,6 +74,8 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
    adapt_to_pi: bool = False
    # Repack transforms. Default is used if not provided.
    repack_transforms: _transforms.Group | None = None
    # If true, will disable syncing the dataset from the huggingface hub.
    local_files_only: bool = False

    def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
        norm_stats_path = metadata_dir / self.repo_id / "norm_stats.json"
@@ -115,6 +121,7 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
                ),
            ]
        ),
        local_files_only=self.local_files_only,
    )
@@ -237,6 +244,39 @@ _CONFIGS = [
        weight_loader=weight_loaders.GoogleViTWeightLoader(),
    ),
    #
    # Example configs.
    #
    TrainConfig(
        name="aloha_static_cups_open",
        data=LeRobotAlohaDataConfig(
            repo_id="lerobot/aloha_static_cups_open",
            delta_action_mask=delta_actions.make_bool_mask(6, -1, 6, -1),
            adapt_to_pi=True,
            repack_transforms=_transforms.Group(
                inputs=[
                    _transforms.RepackTransform(
                        {
                            "images": {
                                "cam_high": "observation.images.cam_high",
                                "cam_left_wrist": "observation.images.cam_left_wrist",
                                "cam_right_wrist": "observation.images.cam_right_wrist",
                            },
                            "state": "observation.state",
                            "actions": "action",
                        }
                    )
                ]
            ),
            # Set this to true if you are using a dataset that is not on the huggingface hub.
            local_files_only=False,
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
        num_train_steps=30_000,
        batch_size=64,
        lr_schedule=_optimizer.CosineDecaySchedule(
            warmup_steps=1_000, peak_lr=2.5e-5, decay_steps=30_000, decay_lr=2.5e-6
        ),
    ),
    #
    # Debugging configs.
    #
    TrainConfig(
@@ -88,11 +88,14 @@ def create_dataset(data_config: _config.DataConfig, model: _model.Model) -> Data
    if repo_id == "fake":
        return FakeDataset(model, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(
        repo_id, root=data_config.dataset_root, local_files_only=data_config.local_files_only
    )
    return lerobot_dataset.LeRobotDataset(
        data_config.repo_id,
        delta_timestamps={"action": [t / dataset_meta.fps for t in range(model.action_horizon)]},
        root=data_config.dataset_root,
        local_files_only=data_config.local_files_only,
    )
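As a quick illustration of the `delta_timestamps` computation in the call above, with hypothetical values (`fps=50` and `action_horizon=3` are assumptions, not taken from this commit):

```python
fps, action_horizon = 50, 3
# One timestamp per future action step, spaced at the dataset frame rate.
delta_timestamps = {"action": [t / fps for t in range(action_horizon)]}
assert delta_timestamps == {"action": [0.0, 0.02, 0.04]}
```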