add binning jointpos

This commit is contained in:
Karl Pertsch
2025-04-25 05:28:04 +00:00
parent c23bc86a0a
commit b84cc75031

View File

@@ -524,6 +524,26 @@ _CONFIGS = [
),
),
),
TrainConfig(
name="paligemma_binning_droid_jointpos",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[
_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
droid_policy.DroidOutputs(),
],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.BinningTokenizer,
),
),
),
TrainConfig(
name="paligemma_fast_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),