Add example for training on custom openpi datasets (#140)
This commit is contained in:
@@ -63,6 +63,8 @@ uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
|
|||||||
The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.
|
The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.
|
||||||
|
|
||||||
The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.
|
The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.
|
||||||
|
|
||||||
|
An example for how to train on your own Aloha dataset is provided in the [ALOHA Real README](examples/aloha_real/README.md).
|
||||||
|
|
||||||
## Running examples
|
## Running examples
|
||||||
|
|
||||||
|
|||||||
@@ -54,20 +54,6 @@ While we strongly recommend fine-tuning the model to your own data to adapt it t
|
|||||||
|
|
||||||
## Training on your own Aloha dataset
|
## Training on your own Aloha dataset
|
||||||
|
|
||||||
OpenPI suppports training on data collected in the default aloha hdf5 format. To do so you must first convert the data to the huggingface format. We include `scripts/aloha_hd5.py` to help you do this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/configs.py` and replace repo id with the id assigned to your dataset during conversion.
|
OpenPI suppports training on data collected in the default aloha hdf5 format using the `scripts/aloha_hd5.py` conversion script. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/configs.py` (see the `aloha_static_cups_open` example config) and replace repo id with the id assigned to your dataset during conversion. Before training on a new dataset, you must first compute the norm stats using `scripts/compute_norm_stats.py`.
|
||||||
|
|
||||||
```python
|
NOTE: When finetuning the pi0 base model on Aloha data it is recommended that you set `adapt_to_pi=True`. This maps the state and action spaces from the original aloha data to the state and action spaces of the aloha data used to train the base model.
|
||||||
TrainConfig(
|
|
||||||
name=<your-config-name>,
|
|
||||||
data=LeRobotAlohaDataConfig(
|
|
||||||
repo_id=<your-repo-id>,
|
|
||||||
delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
```
|
|
||||||
|
|
||||||
Run the training script:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv run scripts/train.py <your-config-name>
|
|
||||||
```
|
|
||||||
@@ -16,6 +16,7 @@ from openpi.policies import libero_policy
|
|||||||
from openpi.policies import policy as _policy
|
from openpi.policies import policy as _policy
|
||||||
from openpi.policies import policy_config as _policy_config
|
from openpi.policies import policy_config as _policy_config
|
||||||
from openpi.serving import websocket_policy_server
|
from openpi.serving import websocket_policy_server
|
||||||
|
from openpi.shared import delta_actions
|
||||||
from openpi.training import config as _config
|
from openpi.training import config as _config
|
||||||
|
|
||||||
|
|
||||||
@@ -146,7 +147,7 @@ def create_default_policy(
|
|||||||
logging.info("Creating policy...")
|
logging.info("Creating policy...")
|
||||||
match env:
|
match env:
|
||||||
case EnvMode.ALOHA:
|
case EnvMode.ALOHA:
|
||||||
delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
|
delta_action_mask = delta_actions.make_bool_mask(6, -1, 6, -1)
|
||||||
config = make_policy_config(
|
config = make_policy_config(
|
||||||
input_layers=[
|
input_layers=[
|
||||||
aloha_policy.ActInputsRepack(),
|
aloha_policy.ActInputsRepack(),
|
||||||
|
|||||||
@@ -99,25 +99,3 @@ def create_trained_policy(
|
|||||||
],
|
],
|
||||||
sample_kwargs=sample_kwargs,
|
sample_kwargs=sample_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
|
|
||||||
"""Make a boolean mask for the given dimensions.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
|
|
||||||
make_bool_mask(2, 0, 2) == (True, True, True, True)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dims: The dimensions to make the mask for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of booleans.
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
for dim in dims:
|
|
||||||
if dim > 0:
|
|
||||||
result.extend([True] * (dim))
|
|
||||||
else:
|
|
||||||
result.extend([False] * (-dim))
|
|
||||||
return tuple(result)
|
|
||||||
|
|||||||
@@ -2,11 +2,6 @@ from openpi.policies import policy_config as _policy_config
|
|||||||
from openpi.training import config as _config
|
from openpi.training import config as _config
|
||||||
|
|
||||||
|
|
||||||
def test_make_bool_mask():
|
|
||||||
assert _policy_config.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
|
|
||||||
assert _policy_config.make_bool_mask(2, 0, 2) == (True, True, True, True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_trained_policy():
|
def test_create_trained_policy():
|
||||||
policy = _policy_config.create_trained_policy(
|
policy = _policy_config.create_trained_policy(
|
||||||
_config.get_config("debug"),
|
_config.get_config("debug"),
|
||||||
|
|||||||
20
src/openpi/shared/delta_actions.py
Normal file
20
src/openpi/shared/delta_actions.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
|
||||||
|
"""Make a boolean mask for the given dimensions.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
|
||||||
|
make_bool_mask(2, 0, 2) == (True, True, True, True)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dims: The dimensions to make the mask for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of booleans.
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for dim in dims:
|
||||||
|
if dim > 0:
|
||||||
|
result.extend([True] * (dim))
|
||||||
|
else:
|
||||||
|
result.extend([False] * (-dim))
|
||||||
|
return tuple(result)
|
||||||
6
src/openpi/shared/delta_actions_test.py
Normal file
6
src/openpi/shared/delta_actions_test.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from openpi.shared import delta_actions
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_bool_mask():
|
||||||
|
assert delta_actions.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
|
||||||
|
assert delta_actions.make_bool_mask(2, 0, 2) == (True, True, True, True)
|
||||||
@@ -14,6 +14,7 @@ import openpi.models.pi0 as pi0
|
|||||||
import openpi.models.pi0_small as pi0_small
|
import openpi.models.pi0_small as pi0_small
|
||||||
import openpi.models.tokenizer as _tokenizer
|
import openpi.models.tokenizer as _tokenizer
|
||||||
import openpi.policies.aloha_policy as aloha_policy
|
import openpi.policies.aloha_policy as aloha_policy
|
||||||
|
from openpi.shared import delta_actions
|
||||||
import openpi.shared.download as download
|
import openpi.shared.download as download
|
||||||
import openpi.shared.normalize as _normalize
|
import openpi.shared.normalize as _normalize
|
||||||
import openpi.training.optimizer as _optimizer
|
import openpi.training.optimizer as _optimizer
|
||||||
@@ -45,6 +46,9 @@ class DataConfig:
|
|||||||
# Indicates where the cached dataset should be stored.
|
# Indicates where the cached dataset should be stored.
|
||||||
dataset_root: str | None = dataclasses.field(default_factory=default_dataset_root)
|
dataset_root: str | None = dataclasses.field(default_factory=default_dataset_root)
|
||||||
|
|
||||||
|
# If true, will disable syncing the dataset from the huggingface hub. Allows training on local-only datasets.
|
||||||
|
local_files_only: bool = False
|
||||||
|
|
||||||
|
|
||||||
class DataConfigFactory(Protocol):
|
class DataConfigFactory(Protocol):
|
||||||
def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
|
def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
|
||||||
@@ -70,6 +74,8 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
|
|||||||
adapt_to_pi: bool = False
|
adapt_to_pi: bool = False
|
||||||
# Repack transforms. Default is used if not provided.
|
# Repack transforms. Default is used if not provided.
|
||||||
repack_transforms: _transforms.Group | None = None
|
repack_transforms: _transforms.Group | None = None
|
||||||
|
# If true, will disable syncing the dataset from the huggingface hub.
|
||||||
|
local_files_only: bool = False
|
||||||
|
|
||||||
def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
|
def create(self, metadata_dir: pathlib.Path, model: _model.Model) -> DataConfig:
|
||||||
norm_stats_path = metadata_dir / self.repo_id / "norm_stats.json"
|
norm_stats_path = metadata_dir / self.repo_id / "norm_stats.json"
|
||||||
@@ -115,6 +121,7 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
local_files_only=self.local_files_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -237,6 +244,39 @@ _CONFIGS = [
|
|||||||
weight_loader=weight_loaders.GoogleViTWeightLoader(),
|
weight_loader=weight_loaders.GoogleViTWeightLoader(),
|
||||||
),
|
),
|
||||||
#
|
#
|
||||||
|
# Example configs.
|
||||||
|
#
|
||||||
|
TrainConfig(
|
||||||
|
name="aloha_static_cups_open",
|
||||||
|
data=LeRobotAlohaDataConfig(
|
||||||
|
repo_id="lerobot/aloha_static_cups_open",
|
||||||
|
delta_action_mask=delta_actions.make_bool_mask(6, -1, 6, -1),
|
||||||
|
adapt_to_pi=True,
|
||||||
|
repack_transforms=_transforms.Group(
|
||||||
|
inputs=[
|
||||||
|
_transforms.RepackTransform(
|
||||||
|
{
|
||||||
|
"images": {
|
||||||
|
"cam_high": "observation.images.cam_high",
|
||||||
|
"cam_left_wrist": "observation.images.cam_left_wrist",
|
||||||
|
"cam_right_wrist": "observation.images.cam_right_wrist",
|
||||||
|
},
|
||||||
|
"state": "observation.state",
|
||||||
|
"actions": "action",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
# Set this to true if you are using a dataset that is not on the huggingface hub.
|
||||||
|
local_files_only=False,
|
||||||
|
),
|
||||||
|
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
|
||||||
|
num_train_steps=30_000,
|
||||||
|
batch_size=64,
|
||||||
|
lr_schedule=_optimizer.CosineDecaySchedule(
|
||||||
|
warmup_steps=1_000, peak_lr=2.5e-5, decay_steps=30_000, decay_lr=2.5e-6
|
||||||
|
),
|
||||||
|
),
|
||||||
# Debugging configs.
|
# Debugging configs.
|
||||||
#
|
#
|
||||||
TrainConfig(
|
TrainConfig(
|
||||||
|
|||||||
@@ -88,11 +88,14 @@ def create_dataset(data_config: _config.DataConfig, model: _model.Model) -> Data
|
|||||||
if repo_id == "fake":
|
if repo_id == "fake":
|
||||||
return FakeDataset(model, num_samples=1024)
|
return FakeDataset(model, num_samples=1024)
|
||||||
|
|
||||||
dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
|
dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(
|
||||||
|
repo_id, root=data_config.dataset_root, local_files_only=data_config.local_files_only
|
||||||
|
)
|
||||||
return lerobot_dataset.LeRobotDataset(
|
return lerobot_dataset.LeRobotDataset(
|
||||||
data_config.repo_id,
|
data_config.repo_id,
|
||||||
delta_timestamps={"action": [t / dataset_meta.fps for t in range(model.action_horizon)]},
|
delta_timestamps={"action": [t / dataset_meta.fps for t in range(model.action_horizon)]},
|
||||||
root=data_config.dataset_root,
|
root=data_config.dataset_root,
|
||||||
|
local_files_only=data_config.local_files_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user