@@ -0,0 +1,973 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
[Paper](https://arxiv.org/abs/2501.09747)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/pi0fast_base \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of training the pi0+FAST neural network with from scratch:
```bash
python lerobot/scripts/train.py \
--policy.type=pi0fast \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = PI0FASTPolicy.from_pretrained( " lerobot/pi0fast_base " )
```
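
Minimal inference sketch (illustrative only; the observation keys and shapes below are assumptions and
depend on your dataset/robot configuration):
```python
import torch

policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
policy.reset()
batch = {
    "observation.state": torch.zeros(1, 14),                # proprioceptive state (illustrative shape)
    "observation.images.top": torch.rand(1, 3, 224, 224),   # camera key depends on your config
    "task": ["pick up the cube"],                            # natural-language instruction
}
action = policy.select_action(batch)  # one action step; repeated calls pop from the internal queue
```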
"""

from collections import deque
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
from PIL import Image
from scipy.fft import idct
from torch import Tensor, nn
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING

from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy

PRECISION = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def normalize(x, min_val, max_val):
    return (x - min_val) / (max_val - min_val)


def unnormalize(x, min_val, max_val):
    return x * (max_val - min_val) + min_val


def safe_arcsin(value):
    # This ensures that the input stays within [-1, 1] to avoid invalid values for arcsin.
    return torch.arcsin(torch.clamp(value, -1.0, 1.0))


def aloha_gripper_to_angular(value):
    # Aloha transforms the gripper positions into a linear space. The following code
    # reverses this transformation to be consistent with pi0 which is pretrained in
    # angular space.
    #
    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    value = unnormalize(value, min_val=0.01844, max_val=0.05800)

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return safe_arcsin(value)

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # Normalize to [0, 1].
    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    return normalize(value, min_val=0.4, max_val=1.5)


def aloha_gripper_from_angular(value):
    # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
    # Note that the units are still angular but the range is different.

    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    value = unnormalize(value, min_val=0.4, max_val=1.5)

    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return normalize(value, min_val=-0.6213, max_val=1.4910)


def aloha_gripper_from_angular_inv(value):
    # Directly inverts the gripper_from_angular function.
    value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return normalize(value, min_val=0.4, max_val=1.5)


class PI0FASTPolicy(PreTrainedPolicy):
    """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""

    config_class = PI0FASTConfig
    name = "pi0fast"

    def __init__(
        self,
        config: PI0FASTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FAST(config)

        self.reset()

    def reset(self):
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def get_optim_params(self) -> dict:
        return self.parameters()

    def _pi_aloha_decode_state(self, state):
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            state[:, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
        return state

    def _pi_aloha_encode_actions(self, actions):
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
        return actions

    def _pi_aloha_encode_actions_inv(self, actions):
        # Flip the joints again.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
        return actions

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `select_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
        queue is empty.
        """
        self.eval()

        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])

        batch = self.normalize_inputs(batch)

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            actions = self.model.generate_actions(batch)

            actions = actions[:, : self.config.n_action_steps]

            original_action_dim = self.config.action_feature.shape[0]
            actions = actions[:, :, :original_action_dim]

            actions = self.unnormalize_outputs({"action": actions})["action"]

            if self.config.adapt_to_pi_aloha:
                actions = self._pi_aloha_encode_actions(actions)

            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        loss_dict = self.model.forward(batch)
        return loss_dict["loss"], loss_dict


def block_causal_update_causal_mask(
    attention_mask,
    token_type_ids=None,
    past_key_values=None,
    cache_position=None,
    input_tensor=None,
    attn_implementation: str = "eager",
    dtype: torch.dtype = torch.float32,
):
    """
    Update the causal mask during training and generation. It can be customized to different attention masks.
    """
    if attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None

    using_static_cache = isinstance(past_key_values, StaticCache)
    min_dtype = torch.finfo(dtype).min

    if input_tensor is None:
        input_tensor = attention_mask

    inputs_lead_dim, sequence_length = input_tensor.shape[:2]

    if using_static_cache or isinstance(past_key_values, HybridCache):
        target_length = past_key_values.get_max_cache_shape()
    else:
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else cache_position[0] + sequence_length + 1
        )

    # Handle precomputed attention masks
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    # Causal mask initialization
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )

    # Standard causal masking (triu ensures tokens can only attend to past)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)

        # Apply block causal mask
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(causal_mask.device).bool()
            cumsum = torch.cumsum(token_type_ids, dim=1)
            block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]

            # Combine causal_mask with block-wise attention mask
            causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
            causal_mask = causal_mask[:, None, :, :]
        else:
            # Apply past cache position constraint
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                -1, 1
            )
            causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
    else:
        # Apply past cache position constraint
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
            -1, 1
        )
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # Copy to contiguous memory for in-place edits
        mask_length = attention_mask.shape[-1]

        # Apply padding mask
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
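

# Illustration (added for clarity, not executed): with token_type_ids = [0, 0, 1, 1] (two prefix tokens,
# two suffix tokens), the mask built above lets the prefix tokens attend to each other in both directions,
# while each suffix token attends to the full prefix and only to suffix tokens at or before its own
# position, i.e.
#   allowed = [[1, 1, 0, 0],
#              [1, 1, 0, 0],
#              [1, 1, 1, 0],
#              [1, 1, 1, 1]]
# (1 = attend with additive value 0, 0 = masked with the dtype minimum).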


def prepare_inputs_for_generation(
    # self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    pixel_values=None,
    attention_mask=None,
    token_type_ids=None,
    use_cache=True,
    num_logits_to_keep=None,
    labels=None,
    self=None,
    **kwargs,
):
    # create block causal attention
    if cache_position[0] > 0 and input_ids.shape[1] > 0:
        input_tensor = input_ids[:, -1:]
        new_positions = (
            torch.ones(
                (position_ids.shape[0], input_ids.shape[1]),
                dtype=position_ids.dtype,
                device=position_ids.device,
            ).cumsum(-1)
            + position_ids[:, -1:]
        )
        position_ids = torch.cat([position_ids, new_positions], dim=-1)
    else:
        input_tensor = inputs_embeds
    attention_mask = block_causal_update_causal_mask(
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        cache_position=cache_position,
        input_tensor=input_tensor,
        token_type_ids=token_type_ids,
        dtype=self.dtype,
        attn_implementation=self.config.text_config._attn_implementation,
    )
    # Overwritten -- custom `position_ids` and `pixel_values` handling
    model_inputs = self.language_model.prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=use_cache,
        num_logits_to_keep=num_logits_to_keep,
        token_type_ids=token_type_ids,
        **kwargs,
    )

    # Position_ids in Paligemma are 1-indexed
    if model_inputs.get("position_ids") is not None:
        model_inputs["position_ids"] += 1

    # If we're in the cached decoding stage, pixel values should be None because input ids do not contain
    # the special image token anymore. Otherwise pixel values need to be passed to the model.
    # NOTE: use_cache=False always needs pixel_values.
    if cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
    is_training = token_type_ids is not None and labels is not None
    if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
        input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
        )
        model_inputs["attention_mask"] = causal_mask

    return model_inputs


class PI0FAST(nn.Module):
    def __init__(self, config: PI0FASTConfig):
        super().__init__()
        self.config = config

        # TODO: move tokenizers in Policy
        fast_tokenizer_path = "physical-intelligence/fast"
        pi0_paligemma_path = "google/paligemma-3b-pt-224"
        self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
        self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
        self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
        self.fast_skip_tokens = self.config.fast_skip_tokens
        self.max_input_seq_len = self.config.max_input_seq_len
        self.action_horizon = self.config.chunk_size
        self.action_dim = self.config.action_feature.shape[0]
        precision = config.precision
        torch_precision = PRECISION.get(precision, torch.float32)
        self.pad_token_id = (
            self.paligemma_tokenizer.pad_token_id
            if hasattr(self.paligemma_tokenizer, "pad_token_id")
            else self.paligemma_tokenizer.eos_token_id
        )

        paligemma_config = CONFIG_MAPPING["paligemma"](
            transformers_version="4.48.1",
            _vocab_size=257152,
            bos_token_id=2,
            eos_token_id=1,
            hidden_size=2048,
            image_token_index=257152,
            model_type="paligemma",
            pad_token_id=0,
            projection_dim=2048,
            text_config={
                "hidden_activation": "gelu_pytorch_tanh",
                "hidden_size": 2048,
                "intermediate_size": 16384,
                "model_type": "gemma",
                "num_attention_heads": 8,
                "num_hidden_layers": 18,
                "num_image_tokens": 256,
                "num_key_value_heads": 1,
                "torch_dtype": precision,
                "vocab_size": 257152,
                "_attn_implementation": "eager",
            },
            vision_config={
                "hidden_size": 1152,
                "intermediate_size": 4304,
                "model_type": "siglip_vision_model",
                "num_attention_heads": 16,
                "num_hidden_layers": 27,
                "num_image_tokens": 256,
                "patch_size": 14,
                "projection_dim": 2048,
                "projector_hidden_act": "gelu_pytorch_tanh",
                "torch_dtype": precision,
                "vision_use_head": False,
            },
        )
        self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)

        self.pi0_paligemma.prepare_inputs_for_generation = partial(
            prepare_inputs_for_generation, self=self.pi0_paligemma
        )
        # change important stuff in bf16
        params_to_change_dtype = [
            "language_model",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.pi0_paligemma.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch_precision)
        self.set_requires_grad()
        self.image_keys = self.config.image_features.keys()
        self.ignore_index = self.pi0_paligemma.config.ignore_index
        self.padding_side = self.config.padding_side

    def set_requires_grad(self):
        if self.config.freeze_vision_encoder:
            self.pi0_paligemma.vision_tower.eval()
            for params in self.pi0_paligemma.vision_tower.parameters():
                params.requires_grad = False

        # To avoid unused params issue with distributed training
        if self.config.freeze_lm_head:
            for name, params in self.pi0_paligemma.named_parameters():
                if "embed_tokens" in name:  # lm heads and embedding layer are tied
                    params.requires_grad = False

    def embed_tokens(self, tokens: torch.Tensor):
        return self.pi0_paligemma.language_model.model.embed_tokens(tokens)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)

    def prepare_images(self, batch):
        """Preprocess LeRobot batch into Pi0 inputs"""
        images = []
        img_masks = []
        present_img_keys = [key for key in self.image_keys if key in batch]
        if len(present_img_keys) == 0:
            raise ValueError(
                f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features: {self.config.image_features})"
            )

        # Preprocess image features present in the batch
        num_empty_cameras = 0
        for key in self.image_keys:
            if key in present_img_keys:
                img = batch[key]

                if self.config.resize_imgs_with_padding is not None:
                    img = resize_with_pad(
                        img,
                        *self.config.resize_imgs_with_padding,
                        pad_value=0,
                        interpolate_like_pi=self.config.interpolate_like_pi,
                    )

                # Normalize from range [0,1] to [-1,1] as expected by siglip
                img = img * 2.0 - 1.0

                bsize = img.shape[0]
                device = img.device
                mask = torch.ones(bsize, dtype=torch.bool, device=device)
            else:
                if num_empty_cameras >= self.config.empty_cameras:
                    continue
                img = torch.ones_like(img) * -1
                bsize = img.shape[0]
                device = img.device
                mask = torch.ones(bsize, dtype=torch.bool, device=device)
                num_empty_cameras += 1

            images.append(img)
            img_masks.append(mask)
        return images, img_masks

    def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
        mins = actions.amin(dim=(1, 2), keepdim=True)  # [0]
        maxs = actions.amax(dim=(1, 2), keepdim=True)  # [0]
        return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1

    def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
        return out
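
    # Note (added for clarity): the mapping above sends FAST action-token id `t` to PaliGemma vocab id
    # `vocab_size - 1 - fast_skip_tokens - t`, so action tokens occupy a contiguous block at the top of the
    # PaliGemma vocabulary (skipping the last `fast_skip_tokens` reserved ids). The formula is its own
    # inverse, which is why the same method is reused in `extract_actions` to map generated PaliGemma ids
    # back to FAST token ids.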

    def fast_tokenizer_wrapper(self, actions_norm):
        """
        A wrapper for self.fast_tokenizer that ensures batch processing,
        conversion to PyTorch tensors, and returns a dictionary without padding.
        """
        batch_tokens = self.fast_tokenizer(actions_norm)
        fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")

        return fast_out

    def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
        token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
        # Compute cumulative sum mask
        cumsum_mask = (padded_mask != 0).cumsum(dim=1)
        # Suffix block (everything after prefix_len)
        suffix_mask = cumsum_mask > prefix_len
        token_type_ids = suffix_mask
        return token_type_ids

    def create_input_tokens(self, state, lang_text, actions=None):
        bsize = state.shape[0]
        device = state.device
        bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
        discretized = torch.bucketize(state, bins) - 1
        discretized = discretized[:, :32]

        prefix_texts = []
        state_text = []
        for txt, disc in zip(lang_text, discretized, strict=False):
            cleaned = txt.lower().strip().replace("_", " ")
            state_str = " ".join(str(val.item()) for val in disc)
            prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
            state_text.append(f"State: {state_str};\n")

        prefix_out = self.paligemma_tokenizer(
            prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
        )
        prefix_ids = prefix_out["input_ids"].to(device)
        prefix_mask = prefix_out["attention_mask"].to(device)
        prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()

        if actions is not None:
            actions_norm = self.normalize_actions(actions)
            actions_pad = F.pad(
                actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
            )[:, :, : self.config.max_action_dim]
            fast_out = self.fast_tokenizer_wrapper(
                actions_pad.cpu(),
            )
            act_ids = fast_out["input_ids"]
            act_mask = fast_out["attention_mask"].to(device)

            act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
            # Replace action with 0 to pad tokens
            act_ids = torch.where(
                act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
                self.pad_token_id,
                act_ids,
            )

            eos_token = torch.tensor(
                [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
            ).expand(bsize, -1)
            eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
            bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
            bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
            bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
            act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
            act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
            act_mask = act_mask.to(device)
        else:
            act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
            act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
        final_ids = torch.cat([prefix_ids, act_ids], dim=1)
        final_mask = torch.cat([prefix_mask, act_mask], dim=1)

        batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}

        # Use tokenizer pad function
        padded_output = self.paligemma_tokenizer.pad(
            batch_inputs, padding="longest", max_length=180, return_tensors="pt"
        )
        padded_mask = padded_output["attention_mask"]

        # define tensor of padding lengths
        att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens

        token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)

        padded_output["padded_mask"] = padded_output.pop("attention_mask")
        padded_output["attention_mask"] = att_mask
        # loss is computed not on prefix, and not on padding
        padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
        padded_output["token_type_ids"] = token_type_ids
        return padded_output
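
    # Illustration (added for clarity, not executed): for a task string "pick_the_cube" the model input is
    # laid out as
    #   [image tokens] "Task: pick the cube, State: <discretized state values>;\n" "Action: " <FAST tokens> <eos>
    # where the image/"Task"/"State" prefix attends bidirectionally and the "Action: ..." suffix is decoded
    # causally (see `block_causal_update_causal_mask` and `create_token_type_ids`).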

    def shift_padding_side(
        self,
        tokens: torch.Tensor,
        ar_mask: torch.Tensor,
        padding_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        targets: torch.Tensor,
        token_type_ids: torch.Tensor,
        padding_side: str = "right",
    ) -> tuple[torch.Tensor]:
        if padding_side not in ["right", "left"]:
            return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids

        new_tokens = torch.empty_like(tokens)
        new_ar_masks = torch.empty_like(ar_mask)
        new_padding_mask = torch.empty_like(padding_mask)
        new_loss_mask = torch.empty_like(loss_mask)
        new_targets = torch.empty_like(targets)
        new_token_type_ids = torch.empty_like(token_type_ids)
        batch_size = tokens.shape[0]
        for i in range(batch_size):
            padding_indices = torch.where(padding_mask[i] == 0)[0]
            non_padding_indices = torch.where(padding_mask[i] == 1)[0]
            if padding_side == "left":
                new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
            else:
                new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
            new_tokens[i] = tokens[i].index_select(0, new_indices)
            new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
            new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
            new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
            new_targets[i] = targets[i].index_select(0, new_indices)
            new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)

        return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids

    def forward(self, batch: dict[str, Tensor]):
        device = batch[OBS_ROBOT].device
        # TODO: keep like this or move to the policy .forward
        images, img_masks = self.prepare_images(batch)

        padded_outs = self.create_input_tokens(
            state=batch[OBS_ROBOT],
            lang_text=batch["task"],
            actions=batch[ACTION],
        )

        embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
            images,
            img_masks,
            padded_outs["input_ids"],
            padded_outs["padded_mask"],
            padded_outs["attention_mask"],
            padded_outs["loss_mask"],
            padded_outs["token_type_ids"],
            padding_side=self.padding_side,
        )
        position_ids = torch.cumsum(pad_masks, dim=1) - 1
        token_type_ids = token_type_ids.to(dtype=torch.int64)
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
        pad_masks = block_causal_update_causal_mask(
            attention_mask=pad_masks,
            past_key_values=None,
            cache_position=cache_position,
            input_tensor=embs,
            token_type_ids=token_type_ids,
            dtype=self.pi0_paligemma.dtype,
            attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
        )
        outputs = self.pi0_paligemma.forward(
            input_ids=None,
            token_type_ids=None,
            attention_mask=pad_masks,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=embs,
            use_cache=False,
            labels=None,
        )

        logits = outputs.logits

        loss_fct = nn.CrossEntropyLoss(reduction="none")

        # Shift left for next-step prediction
        logits = logits[:, :-1, :]
        targets = targets[:, 1:].to(device)  # Shift targets
        loss_mask = loss_mask[:, 1:].to(device)  # Ensure correct shape

        # Compute per-token loss
        token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))

        # Apply loss mask
        token_loss = token_loss * loss_mask.reshape(-1)

        # Compute final loss
        loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)

        # Return loss dictionary
        loss_dict = {"ce_loss": loss.item(), "loss": loss}
        return loss_dict

    def decode_actions_with_fast(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
        relaxed_decoding: bool = True,
    ) -> np.array:
        """
        Adapt original decoding in FAST to always return actions instead of zeros.
        """
        self.time_horizon = (
            time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
        )
        self.action_dim = (
            action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
        )

        # Cache the time horizon and action dimension for the next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert self.time_horizon is not None and self.action_dim is not None, (
            "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
        )

        decoded_actions = []
        for token in tokens:
            try:
                decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
                if relaxed_decoding:
                    # Expected sequence length
                    expected_seq_len = self.time_horizon * self.action_dim
                    diff = expected_seq_len - decoded_dct_coeff.shape[0]
                    # Apply truncation if too long
                    if diff < 0:
                        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # Truncate on the right
                    # Apply padding if too short
                    elif diff > 0:
                        decoded_dct_coeff = np.pad(
                            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
                        )

                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
                assert decoded_dct_coeff.shape == (
                    self.time_horizon,
                    self.action_dim,
                ), (
                    f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
                )
            except Exception as e:
                print(f"Error decoding tokens: {e}")
                print(f"Tokens: {token}")
                decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
            decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
        return np.stack(decoded_actions)
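
    # Decoding pipeline (added for clarity): each token sequence is BPE-decoded to a string, each character
    # is mapped back to an integer DCT coefficient via `ord(...) + min_token`, the coefficients are reshaped
    # to (time_horizon, action_dim), and an inverse DCT along the time axis (scaled by the tokenizer's
    # `scale`) recovers the continuous action chunk.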

    def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the FAST model.

        Args:
            tokens (torch.Tensor): The input tensor of tokenized outputs.
            action_horizon (int): The number of timesteps for actions.
            action_dim (int): The dimensionality of each action.

        Returns:
            torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
        """
        # Decode predicted output tokens
        decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
        cleaned_tokens = [
            tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
            for tokens_sequence in decoded_tokens
        ]
        raw_action_tokens = [
            self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
            for sample_tokens in cleaned_tokens
        ]
        action_tokens = [
            self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
        ]
        # returns the tensor of decoded actions per sample in a list
        decoded_actions = [
            torch.tensor(
                self.decode_actions_with_fast(
                    tok.tolist(),
                    time_horizon=action_horizon,
                    action_dim=action_dim,
                    relaxed_decoding=self.config.relaxed_action_decoding,
                ),
                device=tokens.device,
            ).squeeze(0)
            for tok in action_tokens
        ]
        return torch.stack(
            decoded_actions,
            dim=0,
        )

    def generate_actions(self, batch: dict[str, Tensor]):
        # TODO: keep like this or move to the policy .forward
        images, img_masks = self.prepare_images(batch)

        padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
        embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
            images,
            img_masks,
            padded_outs["input_ids"],
            padded_outs["padded_mask"],
            padded_outs["attention_mask"],
            padded_outs["loss_mask"],
            padded_outs["token_type_ids"],
            padding_side="left",
        )
        token_type_ids = token_type_ids.to(dtype=torch.int64)
        prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
        output_tokens = self.pi0_paligemma.generate(
            input_ids=None,
            attention_mask=pad_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=embs,
            use_cache=self.config.use_cache,
            max_new_tokens=self.config.max_decoding_steps,
            do_sample=False,
            num_beams=1,
            token_type_ids=token_type_ids,
        )
        actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
        return actions

    def embed_image(self, image: torch.Tensor):
        return self.pi0_paligemma.get_image_features(image)

    def embed_inputs(
        self,
        images,
        img_masks,
        tokens,
        pad_mask,
        ar_mask,
        loss_mask,
        token_type_ids,
        padding_side: str = "right",
    ):
        # TODO: avoid list in python and torch.cat; prefer pre-allocation with torch.empty
        # images are a list of same size
        # vectorizing everything!
        device = images[0].device
        image_embedding_dim = images[0].shape[-1]  # TODO should be from self.config
        all_images = torch.stack(images, dim=1).to(device)
        b, n, c, h, w = all_images.shape
        all_images = all_images.view(b * n, c, h, w)
        embedded = self.embed_image(all_images).to(device)
        b_n, p, image_embedding_dim = embedded.shape  # Extract current dimensions
        m = b_n // b  # Compute the number of images per sample dynamically

        # Reshape dynamically
        embedded = embedded.view(b, m, p, image_embedding_dim)

        tokens_embs = self.embed_tokens(tokens.to(device))

        img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
        num_img_emb = embedded.shape[2]
        img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
        img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)

        image_target_tokens = (
            torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
        ).reshape(b, -1)
        image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)

        embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim)  # Shape: (B, N*P, D)

        embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
        pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
        att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
        loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
        targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
        token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)

        # Shift pad tokens to the left (.generate()) or right (.train())
        embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
            embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
        )

        targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
        return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids


def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
    # assume no-op when width height fits already
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    if interpolate_like_pi:
        img = (img * 255.0).to(dtype=torch.uint8)
        img = img.permute(0, 2, 3, 1)
        original_device = img.device
        img = img.to(device="cpu").numpy()
        imgs = []
        for sub_img in img:
            sub_img = Image.fromarray(sub_img)
            resized_img = sub_img.resize((resized_width, resized_height), resample=2)
            resized_img = torch.from_numpy(np.array(resized_img))
            imgs.append(resized_img)
        img = torch.stack(imgs, dim=0)
        img = img.permute(0, 3, 1, 2)
        resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
    else:
        resized_img = F.interpolate(
            img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
        )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # pad on left and top of image
    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
    return padded_img
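

# Example (added for clarity): a (b, 3, 480, 640) batch resized to width=224, height=224 gives
# ratio = max(640/224, 480/224) ≈ 2.86, so the image becomes 224x168 and 56 rows of `pad_value` are added
# at the top (padding is applied on the left and top only), yielding a (b, 3, 224, 224) output.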