add binning jointpos
This commit is contained in:
@@ -524,6 +524,26 @@ _CONFIGS = [
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
TrainConfig(
|
||||||
|
name="paligemma_binning_droid_jointpos",
|
||||||
|
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
|
||||||
|
data=SimpleDataConfig(
|
||||||
|
assets=AssetsConfig(asset_id="droid"),
|
||||||
|
data_transforms=lambda model: _transforms.Group(
|
||||||
|
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
|
||||||
|
outputs=[
|
||||||
|
_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
|
||||||
|
droid_policy.DroidOutputs(),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
base_config=DataConfig(
|
||||||
|
prompt_from_task=True,
|
||||||
|
),
|
||||||
|
model_transforms=ModelTransformFactory(
|
||||||
|
fast_model_tokenizer=_tokenizer.BinningTokenizer,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
TrainConfig(
|
TrainConfig(
|
||||||
name="paligemma_fast_droid",
|
name="paligemma_fast_droid",
|
||||||
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
|
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
|
||||||
|
|||||||
Reference in New Issue
Block a user