multi-node openpi commit
This commit is contained in:
@@ -0,0 +1,116 @@
|
||||
"""RoboArena baseline policy configs."""
|
||||
|
||||
from typing import TypeAlias
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
import openpi.models.pi0_fast as pi0_fast
|
||||
import openpi.models.tokenizer as _tokenizer
|
||||
import openpi.policies.droid_policy as droid_policy
|
||||
import openpi.transforms as _transforms
|
||||
|
||||
ModelType: TypeAlias = _model.ModelType
|
||||
|
||||
|
||||
def get_roboarena_configs():
|
||||
# Import here to avoid circular imports.
|
||||
from openpi.training.config import AssetsConfig
|
||||
from openpi.training.config import DataConfig
|
||||
from openpi.training.config import SimpleDataConfig
|
||||
from openpi.training.config import TrainConfig
|
||||
|
||||
return [
|
||||
#
|
||||
# RoboArena DROID baseline inference configs.
|
||||
#
|
||||
TrainConfig(
|
||||
# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
|
||||
name="paligemma_binning_droid",
|
||||
model=pi0_fast.Pi0FASTConfig(
|
||||
action_dim=8,
|
||||
action_horizon=15,
|
||||
max_token_len=400,
|
||||
fast_model_tokenizer=_tokenizer.BinningTokenizer,
|
||||
),
|
||||
data=SimpleDataConfig(
|
||||
assets=AssetsConfig(asset_id="droid"),
|
||||
data_transforms=lambda model: _transforms.Group(
|
||||
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
|
||||
outputs=[droid_policy.DroidOutputs()],
|
||||
),
|
||||
base_config=DataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
),
|
||||
),
|
||||
TrainConfig(
|
||||
# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
|
||||
name="paligemma_fast_droid",
|
||||
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
|
||||
data=SimpleDataConfig(
|
||||
assets=AssetsConfig(asset_id="droid"),
|
||||
data_transforms=lambda model: _transforms.Group(
|
||||
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
|
||||
outputs=[droid_policy.DroidOutputs()],
|
||||
),
|
||||
base_config=DataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
),
|
||||
),
|
||||
TrainConfig(
|
||||
# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
|
||||
name="paligemma_fast_specialist_droid",
|
||||
model=pi0_fast.Pi0FASTConfig(
|
||||
action_dim=8,
|
||||
action_horizon=15,
|
||||
fast_model_tokenizer=_tokenizer.FASTTokenizer,
|
||||
fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
|
||||
),
|
||||
data=SimpleDataConfig(
|
||||
assets=AssetsConfig(asset_id="droid"),
|
||||
data_transforms=lambda model: _transforms.Group(
|
||||
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
|
||||
outputs=[droid_policy.DroidOutputs()],
|
||||
),
|
||||
base_config=DataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
),
|
||||
),
|
||||
TrainConfig(
|
||||
# Trained from PaliGemma, using FSQ tokenizer.
|
||||
name="paligemma_vq_droid",
|
||||
model=pi0_fast.Pi0FASTConfig(
|
||||
action_dim=8,
|
||||
action_horizon=15,
|
||||
fast_model_tokenizer=_tokenizer.FSQTokenizer,
|
||||
fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
|
||||
),
|
||||
data=SimpleDataConfig(
|
||||
assets=AssetsConfig(asset_id="droid"),
|
||||
data_transforms=lambda model: _transforms.Group(
|
||||
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
|
||||
outputs=[droid_policy.DroidOutputs()],
|
||||
),
|
||||
base_config=DataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
),
|
||||
),
|
||||
TrainConfig(
|
||||
# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
|
||||
name="paligemma_diffusion_droid",
|
||||
model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
|
||||
data=SimpleDataConfig(
|
||||
assets=AssetsConfig(asset_id="droid"),
|
||||
data_transforms=lambda model: _transforms.Group(
|
||||
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
|
||||
outputs=[droid_policy.DroidOutputs()],
|
||||
),
|
||||
base_config=DataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
),
|
||||
),
|
||||
]
|
||||
Reference in New Issue
Block a user