chore(policies): deprecate pi0fast (#2203)
This commit is contained in:
@@ -31,7 +31,6 @@ from lerobot.envs.utils import env_to_policy_features
|
|||||||
from lerobot.policies.act.configuration_act import ACTConfig
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||||
@@ -58,7 +57,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||||
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
|
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The policy class corresponding to the given name.
|
The policy class corresponding to the given name.
|
||||||
@@ -82,10 +81,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||||
|
|
||||||
return VQBeTPolicy
|
return VQBeTPolicy
|
||||||
elif name == "pi0fast":
|
|
||||||
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
|
||||||
|
|
||||||
return PI0FASTPolicy
|
|
||||||
elif name == "pi0":
|
elif name == "pi0":
|
||||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||||
|
|
||||||
@@ -119,7 +114,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||||
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
|
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||||
"reward_classifier".
|
"reward_classifier".
|
||||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||||
|
|
||||||
@@ -137,8 +132,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return ACTConfig(**kwargs)
|
return ACTConfig(**kwargs)
|
||||||
elif policy_type == "vqbet":
|
elif policy_type == "vqbet":
|
||||||
return VQBeTConfig(**kwargs)
|
return VQBeTConfig(**kwargs)
|
||||||
elif policy_type == "pi0fast":
|
|
||||||
return PI0FASTConfig(**kwargs)
|
|
||||||
elif policy_type == "pi0":
|
elif policy_type == "pi0":
|
||||||
return PI0Config(**kwargs)
|
return PI0Config(**kwargs)
|
||||||
elif policy_type == "pi05":
|
elif policy_type == "pi05":
|
||||||
@@ -260,14 +253,6 @@ def make_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(policy_cfg, PI0FASTConfig):
|
|
||||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
|
|
||||||
|
|
||||||
processors = make_pi0fast_pre_post_processors(
|
|
||||||
config=policy_cfg,
|
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(policy_cfg, PI0Config):
|
elif isinstance(policy_cfg, PI0Config):
|
||||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
|
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
|
||||||
|
|
||||||
|
|||||||
@@ -1,153 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|
||||||
from lerobot.optim.optimizers import AdamWConfig
|
|
||||||
from lerobot.optim.schedulers import (
|
|
||||||
CosineDecayWithWarmupSchedulerConfig,
|
|
||||||
)
|
|
||||||
from lerobot.utils.constants import OBS_IMAGES
|
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("pi0fast")
|
|
||||||
@dataclass
|
|
||||||
class PI0FASTConfig(PreTrainedConfig):
|
|
||||||
# Input / output structure.
|
|
||||||
n_obs_steps: int = 1
|
|
||||||
chunk_size: int = 10
|
|
||||||
n_action_steps: int = 5
|
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"VISUAL": NormalizationMode.IDENTITY,
|
|
||||||
"STATE": NormalizationMode.MEAN_STD,
|
|
||||||
"ACTION": NormalizationMode.MEAN_STD,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Shorter state and action vectors will be padded
|
|
||||||
max_state_dim: int = 32 # 32
|
|
||||||
max_action_dim: int = 32 # 32
|
|
||||||
|
|
||||||
# Image preprocessing
|
|
||||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
|
||||||
interpolate_like_pi: bool = False
|
|
||||||
|
|
||||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
|
||||||
# left and right wrist cameras in addition to the top camera.
|
|
||||||
empty_cameras: int = 0
|
|
||||||
|
|
||||||
# Converts the joint and gripper values from the standard Aloha space to
|
|
||||||
# the space used by the pi internal runtime which was used to train the base model.
|
|
||||||
adapt_to_pi_aloha: bool = False
|
|
||||||
|
|
||||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
|
||||||
# Gripper dimensions will remain in absolute values.
|
|
||||||
use_delta_joint_actions_aloha: bool = False
|
|
||||||
|
|
||||||
# Tokenizer
|
|
||||||
tokenizer_max_length: int = 48
|
|
||||||
|
|
||||||
# Projector
|
|
||||||
proj_width: int = 1024
|
|
||||||
|
|
||||||
# Decoding
|
|
||||||
max_decoding_steps: int = 256
|
|
||||||
fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
|
|
||||||
max_input_seq_len: int = 256 # 512
|
|
||||||
|
|
||||||
# Utils
|
|
||||||
use_cache: bool = True
|
|
||||||
|
|
||||||
# Frozen parameters
|
|
||||||
freeze_vision_encoder: bool = True
|
|
||||||
freeze_lm_head: bool = True
|
|
||||||
|
|
||||||
# Training presets
|
|
||||||
optimizer_lr: float = 1e-4
|
|
||||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
|
||||||
optimizer_eps: float = 1e-8
|
|
||||||
optimizer_weight_decay: float = 1e-5
|
|
||||||
|
|
||||||
scheduler_warmup_steps: int = 1_000
|
|
||||||
scheduler_decay_steps: int = 30_000
|
|
||||||
scheduler_decay_lr: float = 2.5e-6
|
|
||||||
|
|
||||||
checkpoint_path: str = None
|
|
||||||
|
|
||||||
padding_side: str = "right"
|
|
||||||
|
|
||||||
precision: str = "bfloat16"
|
|
||||||
grad_clip_norm: float = 1
|
|
||||||
|
|
||||||
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
|
|
||||||
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
|
|
||||||
relaxed_action_decoding: bool = True
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__post_init__()
|
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
|
||||||
if self.n_action_steps > self.chunk_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
|
||||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
|
||||||
)
|
|
||||||
if self.n_obs_steps != 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
|
||||||
for i in range(self.empty_cameras):
|
|
||||||
key = f"{OBS_IMAGES}.empty_camera_{i}"
|
|
||||||
empty_camera = PolicyFeature(
|
|
||||||
type=FeatureType.VISUAL,
|
|
||||||
shape=(3, 480, 640),
|
|
||||||
)
|
|
||||||
self.input_features[key] = empty_camera
|
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamWConfig:
|
|
||||||
return AdamWConfig(
|
|
||||||
lr=self.optimizer_lr,
|
|
||||||
betas=self.optimizer_betas,
|
|
||||||
eps=self.optimizer_eps,
|
|
||||||
weight_decay=self.optimizer_weight_decay,
|
|
||||||
grad_clip_norm=self.grad_clip_norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_scheduler_preset(self):
|
|
||||||
return CosineDecayWithWarmupSchedulerConfig(
|
|
||||||
peak_lr=self.optimizer_lr,
|
|
||||||
decay_lr=self.scheduler_decay_lr,
|
|
||||||
num_warmup_steps=self.scheduler_warmup_steps,
|
|
||||||
num_decay_steps=self.scheduler_decay_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def observation_delta_indices(self) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def action_delta_indices(self) -> list:
|
|
||||||
return list(range(self.chunk_size))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reward_delta_indices(self) -> None:
|
|
||||||
return None
|
|
||||||
@@ -1,980 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
|
|
||||||
|
|
||||||
[Paper](https://huggingface.co/papers/2501.09747)
|
|
||||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
|
||||||
|
|
||||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
|
||||||
Disclaimer: It is not expected to perform as well as the original implementation.
|
|
||||||
|
|
||||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
|
||||||
```bash
|
|
||||||
lerobot-train \
|
|
||||||
--policy.path=lerobot/pi0fast_base \
|
|
||||||
--dataset.repo_id=danaaubakirova/koch_test
|
|
||||||
```
|
|
||||||
|
|
||||||
Example of training the pi0+FAST neural network with from scratch:
|
|
||||||
```bash
|
|
||||||
lerobot-train \
|
|
||||||
--policy.type=pi0fast \
|
|
||||||
--dataset.repo_id=danaaubakirova/koch_test
|
|
||||||
```
|
|
||||||
|
|
||||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
|
||||||
```python
|
|
||||||
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
|
|
||||||
```
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
from PIL import Image
|
|
||||||
from scipy.fft import idct
|
|
||||||
from torch import Tensor, nn
|
|
||||||
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
|
|
||||||
from transformers.cache_utils import HybridCache, StaticCache
|
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
|
||||||
|
|
||||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
|
||||||
|
|
||||||
PRECISION = {
|
|
||||||
"float16": torch.float16,
|
|
||||||
"float32": torch.float32,
|
|
||||||
"bfloat16": torch.bfloat16,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(x, min_val, max_val):
|
|
||||||
return (x - min_val) / (max_val - min_val)
|
|
||||||
|
|
||||||
|
|
||||||
def unnormalize(x, min_val, max_val):
|
|
||||||
return x * (max_val - min_val) + min_val
|
|
||||||
|
|
||||||
|
|
||||||
def safe_arcsin(value):
|
|
||||||
# This ensures that the input stays within
|
|
||||||
# [−1,1] to avoid invalid values for arcsin
|
|
||||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
|
||||||
|
|
||||||
|
|
||||||
def aloha_gripper_to_angular(value):
|
|
||||||
# Aloha transforms the gripper positions into a linear space. The following code
|
|
||||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
|
||||||
# angular space.
|
|
||||||
#
|
|
||||||
# These values are coming from the Aloha code:
|
|
||||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
|
||||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
|
||||||
|
|
||||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
|
||||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
|
||||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
|
||||||
return safe_arcsin(value)
|
|
||||||
|
|
||||||
# The constants are taken from the Interbotix code.
|
|
||||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
|
||||||
|
|
||||||
# Normalize to [0, 1].
|
|
||||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
|
||||||
return normalize(value, min_val=0.4, max_val=1.5)
|
|
||||||
|
|
||||||
|
|
||||||
def aloha_gripper_from_angular(value):
|
|
||||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
|
||||||
# Note that the units are still angular but the range is different.
|
|
||||||
|
|
||||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
|
||||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
|
||||||
|
|
||||||
# These values are coming from the Aloha code:
|
|
||||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
|
||||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
|
||||||
|
|
||||||
|
|
||||||
def aloha_gripper_from_angular_inv(value):
|
|
||||||
# Directly inverts the gripper_from_angular function.
|
|
||||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
|
||||||
return normalize(value, min_val=0.4, max_val=1.5)
|
|
||||||
|
|
||||||
|
|
||||||
class PI0FASTPolicy(PreTrainedPolicy):
|
|
||||||
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
|
|
||||||
|
|
||||||
config_class = PI0FASTConfig
|
|
||||||
name = "pi0fast"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: PI0FASTConfig,
|
|
||||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
|
||||||
the configuration class is used.
|
|
||||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
|
||||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
super().__init__(config)
|
|
||||||
config.validate_features()
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
|
||||||
self.model = PI0FAST(config)
|
|
||||||
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""This should be called whenever the environment is reset."""
|
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
"""Override the from_pretrained method to display important disclaimer."""
|
|
||||||
print(
|
|
||||||
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
|
|
||||||
" It is not expected to perform as well as the original implementation. \n"
|
|
||||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
|
||||||
)
|
|
||||||
return super().from_pretrained(*args, **kwargs)
|
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
|
||||||
return self.parameters()
|
|
||||||
|
|
||||||
def _pi_aloha_decode_state(self, state):
|
|
||||||
# Flip the joints.
|
|
||||||
for motor_idx in [1, 2, 8, 9]:
|
|
||||||
state[:, motor_idx] *= -1
|
|
||||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
|
||||||
for motor_idx in [6, 13]:
|
|
||||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
|
||||||
return state
|
|
||||||
|
|
||||||
def _pi_aloha_encode_actions(self, actions):
|
|
||||||
# Flip the joints.
|
|
||||||
for motor_idx in [1, 2, 8, 9]:
|
|
||||||
actions[:, :, motor_idx] *= -1
|
|
||||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
|
||||||
for motor_idx in [6, 13]:
|
|
||||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
|
||||||
return actions
|
|
||||||
|
|
||||||
def _pi_aloha_encode_actions_inv(self, actions):
|
|
||||||
# Flip the joints again.
|
|
||||||
for motor_idx in [1, 2, 8, 9]:
|
|
||||||
actions[:, :, motor_idx] *= -1
|
|
||||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
|
||||||
for motor_idx in [6, 13]:
|
|
||||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
|
||||||
return actions
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
|
||||||
"""Predict a chunk of actions given environment observations."""
|
|
||||||
raise NotImplementedError("Currently not implemented for PI0FAST")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
|
||||||
"""Select a single action given environment observations.
|
|
||||||
|
|
||||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
|
||||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
|
||||||
queue is empty.
|
|
||||||
"""
|
|
||||||
self.eval()
|
|
||||||
|
|
||||||
if self.config.adapt_to_pi_aloha:
|
|
||||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
|
||||||
|
|
||||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
|
||||||
# querying the policy.
|
|
||||||
if len(self._action_queue) == 0:
|
|
||||||
actions = self.model.generate_actions(batch)
|
|
||||||
|
|
||||||
actions = actions[:, : self.config.n_action_steps]
|
|
||||||
|
|
||||||
original_action_dim = self.config.action_feature.shape[
|
|
||||||
0
|
|
||||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
|
||||||
actions = actions[:, :, :original_action_dim]
|
|
||||||
|
|
||||||
if self.config.adapt_to_pi_aloha:
|
|
||||||
actions = self._pi_aloha_encode_actions(actions)
|
|
||||||
|
|
||||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
|
||||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
|
||||||
return self._action_queue.popleft()
|
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
||||||
if self.config.adapt_to_pi_aloha:
|
|
||||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
|
||||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
|
||||||
loss_dict = self.model.forward(batch)
|
|
||||||
return loss_dict["loss"], loss_dict
|
|
||||||
|
|
||||||
|
|
||||||
def block_causal_update_causal_mask(
|
|
||||||
attention_mask,
|
|
||||||
token_type_ids=None,
|
|
||||||
past_key_values=None,
|
|
||||||
cache_position=None,
|
|
||||||
input_tensor=None,
|
|
||||||
attn_implementation: str = "eager",
|
|
||||||
dtype: torch.dtype = "float32",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update the causal mask during training and generation. It can be customized to different attention masks.
|
|
||||||
"""
|
|
||||||
if attn_implementation == "flash_attention_2":
|
|
||||||
if attention_mask is not None and 0.0 in attention_mask:
|
|
||||||
return attention_mask
|
|
||||||
return None
|
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
|
|
||||||
if input_tensor is None:
|
|
||||||
input_tensor = attention_mask
|
|
||||||
|
|
||||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
|
||||||
|
|
||||||
if using_static_cache or isinstance(past_key_values, HybridCache):
|
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
|
||||||
else:
|
|
||||||
target_length = (
|
|
||||||
attention_mask.shape[-1]
|
|
||||||
if isinstance(attention_mask, torch.Tensor)
|
|
||||||
else cache_position[0] + sequence_length + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle precomputed attention masks
|
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
# Causal mask initialization
|
|
||||||
causal_mask = torch.full(
|
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Standard causal masking (triu ensures tokens can only attend to past)
|
|
||||||
if sequence_length != 1:
|
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
|
||||||
|
|
||||||
# Apply block causal mask
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_ids = token_type_ids.to(causal_mask.device).bool()
|
|
||||||
cumsum = torch.cumsum(token_type_ids, dim=1)
|
|
||||||
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
|
|
||||||
|
|
||||||
# Combine causal_mask with block-wise attention mask
|
|
||||||
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
|
|
||||||
causal_mask = causal_mask[:, None, :, :]
|
|
||||||
else:
|
|
||||||
# Apply past cache position constraint
|
|
||||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
|
||||||
-1, 1
|
|
||||||
)
|
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
|
||||||
else:
|
|
||||||
# Apply past cache position constraint
|
|
||||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
|
||||||
-1, 1
|
|
||||||
)
|
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
|
|
||||||
mask_length = attention_mask.shape[-1]
|
|
||||||
|
|
||||||
# Apply padding mask
|
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
|
||||||
causal_mask.device
|
|
||||||
)
|
|
||||||
padding_mask = padding_mask == 0
|
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
||||||
padding_mask, min_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
return causal_mask
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
|
||||||
# self,
|
|
||||||
input_ids,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=None,
|
|
||||||
cache_position=None,
|
|
||||||
position_ids=None,
|
|
||||||
pixel_values=None,
|
|
||||||
attention_mask=None,
|
|
||||||
token_type_ids=None,
|
|
||||||
use_cache=True,
|
|
||||||
num_logits_to_keep=None,
|
|
||||||
labels=None,
|
|
||||||
self=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# create block causal attention
|
|
||||||
if cache_position[0] > 0 and input_ids.shape[1] > 0:
|
|
||||||
input_tensor = input_ids[:, -1:]
|
|
||||||
new_positions = (
|
|
||||||
torch.ones(
|
|
||||||
(position_ids.shape[0], input_ids.shape[1]),
|
|
||||||
dtype=position_ids.dtype,
|
|
||||||
device=position_ids.device,
|
|
||||||
).cumsum(-1)
|
|
||||||
+ position_ids[:, -1:]
|
|
||||||
)
|
|
||||||
position_ids = torch.cat([position_ids, new_positions], dim=-1)
|
|
||||||
else:
|
|
||||||
input_tensor = inputs_embeds
|
|
||||||
attention_mask = block_causal_update_causal_mask(
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
cache_position=cache_position,
|
|
||||||
input_tensor=input_tensor,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
dtype=self.dtype,
|
|
||||||
attn_implementation=self.config.text_config._attn_implementation,
|
|
||||||
)
|
|
||||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
|
||||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
|
||||||
input_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cache_position=cache_position,
|
|
||||||
use_cache=use_cache,
|
|
||||||
num_logits_to_keep=num_logits_to_keep,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Position_ids in Paligemma are 1-indexed
|
|
||||||
if model_inputs.get("position_ids") is not None:
|
|
||||||
model_inputs["position_ids"] += 1
|
|
||||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
|
||||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
|
||||||
if cache_position[0] == 0:
|
|
||||||
model_inputs["pixel_values"] = pixel_values
|
|
||||||
is_training = token_type_ids is not None and labels is not None
|
|
||||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
|
||||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
|
||||||
causal_mask = self._update_causal_mask(
|
|
||||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
|
||||||
)
|
|
||||||
model_inputs["attention_mask"] = causal_mask
|
|
||||||
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
|
|
||||||
class PI0FAST(nn.Module):
|
|
||||||
def __init__(self, config: PI0FASTConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# TODO: move tokenizers in Policy
|
|
||||||
fast_tokenizer_path = "physical-intelligence/fast"
|
|
||||||
pi0_paligemma_path = "google/paligemma-3b-pt-224"
|
|
||||||
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
|
|
||||||
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
|
|
||||||
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
|
|
||||||
self.fast_skip_tokens = self.config.fast_skip_tokens
|
|
||||||
self.max_input_seq_len = self.config.max_input_seq_len
|
|
||||||
self.action_horizon = self.config.chunk_size
|
|
||||||
self.action_dim = self.config.action_feature.shape[
|
|
||||||
0
|
|
||||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
|
||||||
precision = config.precision
|
|
||||||
torch_precision = PRECISION.get(precision, torch.float32)
|
|
||||||
self.pad_token_id = (
|
|
||||||
self.paligemma_tokenizer.pad_token_id
|
|
||||||
if hasattr(self.paligemma_tokenizer, "pad_token_id")
|
|
||||||
else self.paligemma_tokenizer.eos_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
paligemma_config = CONFIG_MAPPING["paligemma"](
|
|
||||||
transformers_version="4.48.1",
|
|
||||||
_vocab_size=257152,
|
|
||||||
bos_token_id=2,
|
|
||||||
eos_token_id=1,
|
|
||||||
hidden_size=2048,
|
|
||||||
image_token_index=257152,
|
|
||||||
model_type="paligemma",
|
|
||||||
pad_token_id=0,
|
|
||||||
projection_dim=2048,
|
|
||||||
text_config={
|
|
||||||
"hidden_activation": "gelu_pytorch_tanh",
|
|
||||||
"hidden_size": 2048,
|
|
||||||
"intermediate_size": 16384,
|
|
||||||
"model_type": "gemma",
|
|
||||||
"num_attention_heads": 8,
|
|
||||||
"num_hidden_layers": 18,
|
|
||||||
"num_image_tokens": 256,
|
|
||||||
"num_key_value_heads": 1,
|
|
||||||
"torch_dtype": precision,
|
|
||||||
"vocab_size": 257152,
|
|
||||||
"_attn_implementation": "eager",
|
|
||||||
},
|
|
||||||
vision_config={
|
|
||||||
"hidden_size": 1152,
|
|
||||||
"intermediate_size": 4304,
|
|
||||||
"model_type": "siglip_vision_model",
|
|
||||||
"num_attention_heads": 16,
|
|
||||||
"num_hidden_layers": 27,
|
|
||||||
"num_image_tokens": 256,
|
|
||||||
"patch_size": 14,
|
|
||||||
"projection_dim": 2048,
|
|
||||||
"projector_hidden_act": "gelu_pytorch_tanh",
|
|
||||||
"torch_dtype": precision,
|
|
||||||
"vision_use_head": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
|
||||||
|
|
||||||
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
|
||||||
prepare_inputs_for_generation, self=self.pi0_paligemma
|
|
||||||
)
|
|
||||||
# change important stuff in bf16
|
|
||||||
params_to_change_dtype = [
|
|
||||||
"language_model",
|
|
||||||
"vision_tower",
|
|
||||||
"multi_modal",
|
|
||||||
]
|
|
||||||
for name, param in self.pi0_paligemma.named_parameters():
|
|
||||||
if any(selector in name for selector in params_to_change_dtype):
|
|
||||||
param.data = param.data.to(dtype=torch_precision)
|
|
||||||
self.set_requires_grad()
|
|
||||||
self.image_keys = self.config.image_features.keys()
|
|
||||||
# TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
|
|
||||||
# AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
|
|
||||||
self.ignore_index = self.pi0_paligemma.config.ignore_index
|
|
||||||
self.padding_side = self.config.padding_side
|
|
||||||
|
|
||||||
def set_requires_grad(self):
|
|
||||||
if self.config.freeze_vision_encoder:
|
|
||||||
self.pi0_paligemma.vision_tower.eval()
|
|
||||||
for params in self.pi0_paligemma.vision_tower.parameters():
|
|
||||||
params.requires_grad = False
|
|
||||||
# To avoid unused params issue with distributed training
|
|
||||||
if self.config.freeze_lm_head:
|
|
||||||
for name, params in self.pi0_paligemma.named_parameters():
|
|
||||||
if "embed_tokens" in name: # lm heads and embedding layer are tied
|
|
||||||
params.requires_grad = False
|
|
||||||
|
|
||||||
def embed_tokens(self, tokens: torch.Tensor):
|
|
||||||
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
|
||||||
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
|
|
||||||
|
|
||||||
def prepare_images(self, batch):
|
|
||||||
"""Preprocess LeRobot batch into Pi0 inputs"""
|
|
||||||
images = []
|
|
||||||
img_masks = []
|
|
||||||
present_img_keys = [key for key in self.image_keys if key in batch]
|
|
||||||
if len(present_img_keys) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Preprocess image features present in the batch
|
|
||||||
num_empty_cameras = 0
|
|
||||||
for key in self.image_keys:
|
|
||||||
if key in present_img_keys:
|
|
||||||
img = batch[key]
|
|
||||||
|
|
||||||
if self.config.resize_imgs_with_padding is not None:
|
|
||||||
img = resize_with_pad(
|
|
||||||
img,
|
|
||||||
*self.config.resize_imgs_with_padding,
|
|
||||||
pad_value=0,
|
|
||||||
interpolate_like_pi=self.config.interpolate_like_pi,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
|
||||||
img = img * 2.0 - 1.0
|
|
||||||
|
|
||||||
bsize = img.shape[0]
|
|
||||||
device = img.device
|
|
||||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
|
||||||
else:
|
|
||||||
if num_empty_cameras >= self.config.empty_cameras:
|
|
||||||
continue
|
|
||||||
img = torch.ones_like(img) * -1
|
|
||||||
bsize = img.shape[0]
|
|
||||||
device = img.device
|
|
||||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
|
||||||
num_empty_cameras += 1
|
|
||||||
|
|
||||||
images.append(img)
|
|
||||||
img_masks.append(mask)
|
|
||||||
return images, img_masks
|
|
||||||
|
|
||||||
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
|
|
||||||
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
|
|
||||||
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
|
|
||||||
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
|
||||||
|
|
||||||
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
|
||||||
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
|
||||||
return out
|
|
||||||
|
|
||||||
def fast_tokenizer_wrapper(self, actions_norm):
|
|
||||||
"""
|
|
||||||
A wrapper for self.fast_tokenizer that ensures batch processing,
|
|
||||||
conversion to PyTorch tensors, and returns a dictionary without padding.
|
|
||||||
"""
|
|
||||||
batch_tokens = self.fast_tokenizer(actions_norm)
|
|
||||||
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
|
|
||||||
|
|
||||||
return fast_out
|
|
||||||
|
|
||||||
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
|
|
||||||
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
|
|
||||||
# Compute cumulative sum mask
|
|
||||||
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
|
|
||||||
# Suffix block (everything after prefix_len)
|
|
||||||
suffix_mask = cumsum_mask > prefix_len
|
|
||||||
token_type_ids = suffix_mask
|
|
||||||
return token_type_ids
|
|
||||||
|
|
||||||
def create_input_tokens(self, state, lang_text, actions=None):
|
|
||||||
bsize = state.shape[0]
|
|
||||||
device = state.device
|
|
||||||
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
|
|
||||||
discretized = torch.bucketize(state, bins) - 1
|
|
||||||
discretized = discretized[:, :32]
|
|
||||||
|
|
||||||
prefix_texts = []
|
|
||||||
state_text = []
|
|
||||||
for txt, disc in zip(lang_text, discretized, strict=False):
|
|
||||||
cleaned = txt.lower().strip().replace("_", " ")
|
|
||||||
state_str = " ".join(str(val.item()) for val in disc)
|
|
||||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
|
|
||||||
state_text.append(f"State: {state_str};\n")
|
|
||||||
|
|
||||||
prefix_out = self.paligemma_tokenizer(
|
|
||||||
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
|
|
||||||
)
|
|
||||||
prefix_ids = prefix_out["input_ids"].to(device)
|
|
||||||
prefix_mask = prefix_out["attention_mask"].to(device)
|
|
||||||
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
|
|
||||||
|
|
||||||
if actions is not None:
|
|
||||||
actions_norm = self.normalize_actions(actions)
|
|
||||||
actions_pad = F.pad(
|
|
||||||
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
|
|
||||||
)[:, :, : self.config.max_action_dim]
|
|
||||||
fast_out = self.fast_tokenizer_wrapper(
|
|
||||||
actions_pad.cpu(),
|
|
||||||
)
|
|
||||||
act_ids = fast_out["input_ids"]
|
|
||||||
act_mask = fast_out["attention_mask"].to(device)
|
|
||||||
|
|
||||||
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
|
|
||||||
# Replace action with 0 to pad tokens
|
|
||||||
act_ids = torch.where(
|
|
||||||
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
|
||||||
self.pad_token_id,
|
|
||||||
act_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
eos_token = torch.tensor(
|
|
||||||
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
|
|
||||||
).expand(bsize, -1)
|
|
||||||
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
|
|
||||||
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
|
|
||||||
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
|
|
||||||
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
|
|
||||||
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
|
|
||||||
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
|
|
||||||
act_mask = act_mask.to(device)
|
|
||||||
else:
|
|
||||||
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
|
|
||||||
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
|
|
||||||
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
|
|
||||||
|
|
||||||
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
|
|
||||||
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
|
|
||||||
|
|
||||||
# Use tokenizer pad function
|
|
||||||
padded_output = self.paligemma_tokenizer.pad(
|
|
||||||
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
|
|
||||||
)
|
|
||||||
padded_mask = padded_output["attention_mask"]
|
|
||||||
|
|
||||||
# define tensor of padding lengths
|
|
||||||
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
|
|
||||||
|
|
||||||
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
|
|
||||||
|
|
||||||
padded_output["padded_mask"] = padded_output.pop("attention_mask")
|
|
||||||
padded_output["attention_mask"] = att_mask
|
|
||||||
# loss is computed not on prefix, and not on padding
|
|
||||||
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
|
|
||||||
padded_output["token_type_ids"] = token_type_ids
|
|
||||||
return padded_output
|
|
||||||
|
|
||||||
def shift_padding_side(
|
|
||||||
self,
|
|
||||||
tokens: torch.Tensor,
|
|
||||||
ar_mask: torch.Tensor,
|
|
||||||
padding_mask: torch.Tensor,
|
|
||||||
loss_mask: torch.Tensor,
|
|
||||||
targets: torch.Tensor,
|
|
||||||
token_type_ids: torch.Tensor,
|
|
||||||
padding_side: str = "right",
|
|
||||||
) -> tuple[torch.Tensor]:
|
|
||||||
if padding_side not in ["right", "left"]:
|
|
||||||
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
|
|
||||||
|
|
||||||
new_tokens = torch.empty_like(tokens)
|
|
||||||
new_ar_masks = torch.empty_like(ar_mask)
|
|
||||||
new_padding_mask = torch.empty_like(padding_mask)
|
|
||||||
new_loss_mask = torch.empty_like(loss_mask)
|
|
||||||
new_targets = torch.empty_like(targets)
|
|
||||||
new_token_type_ids = torch.empty_like(token_type_ids)
|
|
||||||
batch_size = tokens.shape[0]
|
|
||||||
for i in range(batch_size):
|
|
||||||
padding_indices = torch.where(padding_mask[i] == 0)[0]
|
|
||||||
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
|
|
||||||
if padding_side == "left":
|
|
||||||
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
|
|
||||||
else:
|
|
||||||
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
|
|
||||||
new_tokens[i] = tokens[i].index_select(0, new_indices)
|
|
||||||
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
|
|
||||||
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
|
|
||||||
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
|
|
||||||
new_targets[i] = targets[i].index_select(0, new_indices)
|
|
||||||
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
|
|
||||||
|
|
||||||
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
|
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]):
|
|
||||||
device = batch[OBS_STATE].device
|
|
||||||
# TODO: keep like this or move to the policy .forward
|
|
||||||
images, img_masks = self.prepare_images(batch)
|
|
||||||
|
|
||||||
padded_outs = self.create_input_tokens(
|
|
||||||
state=batch[OBS_STATE],
|
|
||||||
lang_text=batch["task"],
|
|
||||||
actions=batch[ACTION],
|
|
||||||
)
|
|
||||||
|
|
||||||
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
|
|
||||||
images,
|
|
||||||
img_masks,
|
|
||||||
padded_outs["input_ids"],
|
|
||||||
padded_outs["padded_mask"],
|
|
||||||
padded_outs["attention_mask"],
|
|
||||||
padded_outs["loss_mask"],
|
|
||||||
padded_outs["token_type_ids"],
|
|
||||||
padding_side=self.padding_side,
|
|
||||||
)
|
|
||||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
|
||||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
|
||||||
past_seen_tokens = 0
|
|
||||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
|
|
||||||
pad_masks = block_causal_update_causal_mask(
|
|
||||||
attention_mask=pad_masks,
|
|
||||||
past_key_values=None,
|
|
||||||
cache_position=cache_position,
|
|
||||||
input_tensor=embs,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
dtype=self.pi0_paligemma.dtype,
|
|
||||||
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
|
|
||||||
)
|
|
||||||
outputs = self.pi0_paligemma.forward(
|
|
||||||
input_ids=None,
|
|
||||||
token_type_ids=None,
|
|
||||||
attention_mask=pad_masks,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=embs,
|
|
||||||
use_cache=False,
|
|
||||||
labels=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = outputs.logits
|
|
||||||
|
|
||||||
loss_fct = nn.CrossEntropyLoss(reduction="none")
|
|
||||||
|
|
||||||
# Shift left for next-step prediction
|
|
||||||
logits = logits[:, :-1, :]
|
|
||||||
targets = targets[:, 1:].to(device) # Shift targets
|
|
||||||
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
|
|
||||||
|
|
||||||
# Compute per-token loss
|
|
||||||
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
|
|
||||||
|
|
||||||
# Apply loss mask
|
|
||||||
token_loss = token_loss * loss_mask.reshape(-1)
|
|
||||||
|
|
||||||
# Compute final loss
|
|
||||||
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
|
|
||||||
|
|
||||||
# Return loss dictionary
|
|
||||||
loss_dict = {"ce_loss": loss.item(), "loss": loss}
|
|
||||||
return loss_dict
|
|
||||||
|
|
||||||
def decode_actions_with_fast(
|
|
||||||
self,
|
|
||||||
tokens: list[list[int]],
|
|
||||||
*,
|
|
||||||
time_horizon: int | None = None,
|
|
||||||
action_dim: int | None = None,
|
|
||||||
relaxed_decoding: bool = True,
|
|
||||||
) -> np.array:
|
|
||||||
"""
|
|
||||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
|
||||||
"""
|
|
||||||
self.time_horizon = (
|
|
||||||
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
|
|
||||||
)
|
|
||||||
self.action_dim = (
|
|
||||||
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
# Cache the time horizon and action dimension for the next call
|
|
||||||
self.called_time_horizon = self.time_horizon
|
|
||||||
self.called_action_dim = self.action_dim
|
|
||||||
|
|
||||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
|
||||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
|
||||||
)
|
|
||||||
|
|
||||||
decoded_actions = []
|
|
||||||
for token in tokens:
|
|
||||||
try:
|
|
||||||
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
|
||||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
|
||||||
if relaxed_decoding:
|
|
||||||
# Expected sequence length
|
|
||||||
expected_seq_len = self.time_horizon * self.action_dim
|
|
||||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
|
||||||
# Apply truncation if too long
|
|
||||||
if diff < 0:
|
|
||||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
|
|
||||||
# Apply padding if too short
|
|
||||||
elif diff > 0:
|
|
||||||
decoded_dct_coeff = np.pad(
|
|
||||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
|
||||||
)
|
|
||||||
|
|
||||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
|
||||||
assert decoded_dct_coeff.shape == (
|
|
||||||
self.time_horizon,
|
|
||||||
self.action_dim,
|
|
||||||
), (
|
|
||||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error decoding tokens: {e}")
|
|
||||||
print(f"Tokens: {token}")
|
|
||||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
|
||||||
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
|
|
||||||
return np.stack(decoded_actions)
|
|
||||||
|
|
||||||
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Extracts actions from predicted output tokens using the FAST model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens (torch.Tensor): The input tensor of tokenized outputs.
|
|
||||||
action_horizon (int): The number of timesteps for actions.
|
|
||||||
action_dim (int): The dimensionality of each action.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
|
|
||||||
"""
|
|
||||||
# Decode predicted output tokens
|
|
||||||
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
|
|
||||||
cleaned_tokens = [
|
|
||||||
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
|
|
||||||
for tokens_sequence in decoded_tokens
|
|
||||||
]
|
|
||||||
raw_action_tokens = [
|
|
||||||
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
|
|
||||||
for sample_tokens in cleaned_tokens
|
|
||||||
] # something like this should be robust #looks good
|
|
||||||
action_tokens = [
|
|
||||||
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
|
|
||||||
]
|
|
||||||
# returns the tensor of decoded actions per sample in a list
|
|
||||||
decoded_actions = [
|
|
||||||
torch.tensor(
|
|
||||||
self.decode_actions_with_fast(
|
|
||||||
tok.tolist(),
|
|
||||||
time_horizon=action_horizon,
|
|
||||||
action_dim=action_dim,
|
|
||||||
relaxed_decoding=self.config.relaxed_action_decoding,
|
|
||||||
),
|
|
||||||
device=tokens.device,
|
|
||||||
).squeeze(0)
|
|
||||||
for tok in action_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
return torch.stack(
|
|
||||||
decoded_actions,
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_actions(self, batch: dict[str, Tensor]):
|
|
||||||
# TODO: keep like this or move to the policy .forward
|
|
||||||
images, img_masks = self.prepare_images(batch)
|
|
||||||
|
|
||||||
padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None)
|
|
||||||
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
|
|
||||||
images,
|
|
||||||
img_masks,
|
|
||||||
padded_outs["input_ids"],
|
|
||||||
padded_outs["padded_mask"],
|
|
||||||
padded_outs["attention_mask"],
|
|
||||||
padded_outs["loss_mask"],
|
|
||||||
padded_outs["token_type_ids"],
|
|
||||||
padding_side="left",
|
|
||||||
)
|
|
||||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
|
||||||
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
|
||||||
output_tokens = self.pi0_paligemma.generate(
|
|
||||||
input_ids=None,
|
|
||||||
attention_mask=pad_masks,
|
|
||||||
position_ids=prefix_position_ids,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=embs,
|
|
||||||
use_cache=self.config.use_cache,
|
|
||||||
max_new_tokens=self.config.max_decoding_steps,
|
|
||||||
do_sample=False,
|
|
||||||
num_beams=1,
|
|
||||||
token_type_ids=token_type_ids,
|
|
||||||
)
|
|
||||||
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
|
|
||||||
return actions
|
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
|
||||||
# Handle different transformers versions
|
|
||||||
if hasattr(self.pi0_paligemma, "get_image_features"):
|
|
||||||
return self.pi0_paligemma.get_image_features(image)
|
|
||||||
else:
|
|
||||||
return self.pi0_paligemma.model.get_image_features(image)
|
|
||||||
|
|
||||||
def embed_inputs(
|
|
||||||
self,
|
|
||||||
images,
|
|
||||||
img_masks,
|
|
||||||
tokens,
|
|
||||||
pad_mask,
|
|
||||||
ar_mask,
|
|
||||||
loss_mask,
|
|
||||||
token_type_ids,
|
|
||||||
padding_side: str = "right",
|
|
||||||
):
|
|
||||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
|
||||||
# images are a list of same size
|
|
||||||
# vectorizing everything!
|
|
||||||
device = images[0].device
|
|
||||||
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
|
|
||||||
all_images = torch.stack(images, dim=1).to(device)
|
|
||||||
b, n, c, h, w = all_images.shape
|
|
||||||
all_images = all_images.view(b * n, c, h, w)
|
|
||||||
embedded = self.embed_image(all_images).to(device)
|
|
||||||
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
|
|
||||||
m = b_n // b # Compute the number of images per sample dynamically
|
|
||||||
|
|
||||||
# Reshape dynamically
|
|
||||||
embedded = embedded.view(b, m, p, image_embedding_dim)
|
|
||||||
tokens_embs = self.embed_tokens(tokens.to(device))
|
|
||||||
|
|
||||||
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
|
|
||||||
num_img_emb = embedded.shape[2]
|
|
||||||
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
|
|
||||||
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
|
||||||
|
|
||||||
image_target_tokens = (
|
|
||||||
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
|
|
||||||
).reshape(b, -1)
|
|
||||||
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
|
||||||
|
|
||||||
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
|
|
||||||
|
|
||||||
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
|
|
||||||
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
|
|
||||||
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
|
|
||||||
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
|
|
||||||
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
|
|
||||||
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
|
|
||||||
|
|
||||||
# Shift pad tokens to the left (.generate()) or right (.train())
|
|
||||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
|
|
||||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
|
|
||||||
)
|
|
||||||
|
|
||||||
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
|
|
||||||
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
|
|
||||||
|
|
||||||
|
|
||||||
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
|
|
||||||
# assume no-op when width height fits already
|
|
||||||
if img.ndim != 4:
|
|
||||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
|
||||||
|
|
||||||
cur_height, cur_width = img.shape[2:]
|
|
||||||
|
|
||||||
ratio = max(cur_width / width, cur_height / height)
|
|
||||||
resized_height = int(cur_height / ratio)
|
|
||||||
resized_width = int(cur_width / ratio)
|
|
||||||
|
|
||||||
if interpolate_like_pi:
|
|
||||||
img = (img * 255.0).to(dtype=torch.uint8)
|
|
||||||
img = img.permute(0, 2, 3, 1)
|
|
||||||
original_device = img.device
|
|
||||||
img = img.to(device="cpu").numpy()
|
|
||||||
imgs = []
|
|
||||||
for sub_img in img:
|
|
||||||
sub_img = Image.fromarray(sub_img)
|
|
||||||
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
|
|
||||||
resized_img = torch.from_numpy(np.array(resized_img))
|
|
||||||
imgs.append(resized_img)
|
|
||||||
img = torch.stack(imgs, dim=0)
|
|
||||||
img = img.permute(0, 3, 1, 2)
|
|
||||||
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
|
|
||||||
else:
|
|
||||||
resized_img = F.interpolate(
|
|
||||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
|
|
||||||
pad_height = max(0, int(height - resized_height))
|
|
||||||
pad_width = max(0, int(width - resized_width))
|
|
||||||
|
|
||||||
# pad on left and top of image
|
|
||||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
|
||||||
return padded_img
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
|
||||||
from lerobot.processor import (
|
|
||||||
AddBatchDimensionProcessorStep,
|
|
||||||
DeviceProcessorStep,
|
|
||||||
NormalizerProcessorStep,
|
|
||||||
PolicyAction,
|
|
||||||
PolicyProcessorPipeline,
|
|
||||||
RenameObservationsProcessorStep,
|
|
||||||
UnnormalizerProcessorStep,
|
|
||||||
)
|
|
||||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
|
||||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
|
||||||
|
|
||||||
|
|
||||||
def make_pi0fast_pre_post_processors(
|
|
||||||
config: PI0FASTConfig,
|
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
|
||||||
) -> tuple[
|
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
|
||||||
]:
|
|
||||||
"""
|
|
||||||
Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
|
|
||||||
|
|
||||||
The pre-processing pipeline prepares input data for the model by:
|
|
||||||
1. Renaming features to match pretrained configurations.
|
|
||||||
2. Normalizing input and output features based on dataset statistics.
|
|
||||||
3. Adding a batch dimension.
|
|
||||||
4. Moving all data to the specified device.
|
|
||||||
|
|
||||||
The post-processing pipeline handles the model's output by:
|
|
||||||
1. Moving data to the CPU.
|
|
||||||
2. Unnormalizing the output features to their original scale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: The configuration object for the PI0Fast policy.
|
|
||||||
dataset_stats: A dictionary of statistics for normalization.
|
|
||||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
|
||||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_steps = [
|
|
||||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
|
||||||
AddBatchDimensionProcessorStep(),
|
|
||||||
DeviceProcessorStep(device=config.device),
|
|
||||||
NormalizerProcessorStep(
|
|
||||||
features={**config.input_features, **config.output_features},
|
|
||||||
norm_map=config.normalization_mapping,
|
|
||||||
stats=dataset_stats,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
output_steps = [
|
|
||||||
UnnormalizerProcessorStep(
|
|
||||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
|
||||||
),
|
|
||||||
DeviceProcessorStep(device="cpu"),
|
|
||||||
]
|
|
||||||
return (
|
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
|
||||||
steps=input_steps,
|
|
||||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
|
||||||
),
|
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
|
||||||
steps=output_steps,
|
|
||||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
|
||||||
to_transition=policy_action_to_transition,
|
|
||||||
to_output=transition_to_policy_action,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@@ -19,8 +19,6 @@
|
|||||||
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
|
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
|
||||||
{% elif model_name == "vqbet" %}
|
{% elif model_name == "vqbet" %}
|
||||||
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
|
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
|
||||||
{% elif model_name == "pi0fast" %}
|
|
||||||
[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time.
|
|
||||||
{% elif model_name == "pi0" %}
|
{% elif model_name == "pi0" %}
|
||||||
**π₀ (Pi0)**
|
**π₀ (Pi0)**
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user