Files
issacdataengine/policy/openpi-InternData-A1/src/openpi/training/misc/roboarena_config.py
2026-03-17 23:05:23 +08:00

117 lines
4.7 KiB
Python

"""RoboArena baseline policy configs."""
from typing import TypeAlias
import openpi.models.model as _model
import openpi.models.pi0_config as pi0_config
import openpi.models.pi0_fast as pi0_fast
import openpi.models.tokenizer as _tokenizer
import openpi.policies.droid_policy as droid_policy
import openpi.transforms as _transforms
ModelType: TypeAlias = _model.ModelType
def get_roboarena_configs():
# Import here to avoid circular imports.
from openpi.training.config import AssetsConfig
from openpi.training.config import DataConfig
from openpi.training.config import SimpleDataConfig
from openpi.training.config import TrainConfig
return [
#
# RoboArena DROID baseline inference configs.
#
TrainConfig(
# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
name="paligemma_binning_droid",
model=pi0_fast.Pi0FASTConfig(
action_dim=8,
action_horizon=15,
max_token_len=400,
fast_model_tokenizer=_tokenizer.BinningTokenizer,
),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
name="paligemma_fast_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
name="paligemma_fast_specialist_droid",
model=pi0_fast.Pi0FASTConfig(
action_dim=8,
action_horizon=15,
fast_model_tokenizer=_tokenizer.FASTTokenizer,
fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
# Trained from PaliGemma, using FSQ tokenizer.
name="paligemma_vq_droid",
model=pi0_fast.Pi0FASTConfig(
action_dim=8,
action_horizon=15,
fast_model_tokenizer=_tokenizer.FSQTokenizer,
fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
name="paligemma_diffusion_droid",
model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
]