add diffusion jointpos policy

This commit is contained in:
Karl Pertsch
2025-04-17 13:19:48 +00:00
parent e43516e719
commit 650b02e4ca

View File

@@ -476,6 +476,20 @@ _CONFIGS = [
),
),
),
TrainConfig(
name="pi0_droid_jointpos",
model=pi0.Pi0Config(action_horizon=10),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
outputs=[_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)), droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
name="pi0_fast_droid_jointpos",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),