add DROID policies

src/openpi/models/fsq_tokenizer_v2.py (new file)
@@ -0,0 +1,466 @@
import math
from typing import Literal

import chex
from einops import einops
from flax import linen as nn
from flax.linen.module import Module
from flax.linen.module import compact
from flax.struct import dataclass
from flax.typing import Array
import jax
import jax.numpy as jnp


class FsqCodebook(nn.Module):
    input_dim: int
    target_codebook_size: int
    codebook_type: Literal["fsq", "lfq", "custom"]

    _bins_per_dim: tuple[int] | None = None

    @property
    def bins_per_dim(self):
        if self._bins_per_dim is not None:
            return self._bins_per_dim

        if self.codebook_type == "fsq":
            return self._get_bins_fsq(self.target_codebook_size)
        elif self.codebook_type == "lfq":  # noqa: RET505
            return self._get_bins_lfq(self.target_codebook_size)
        elif self.codebook_type == "custom":
            return self._get_bins_custom(self.target_codebook_size)
        else:
            raise ValueError(f"Codebook type {self.codebook_type} not supported.")

    @property
    def place_values(self):
        place_values = [1]
        for b in self.bins_per_dim[:-1]:
            place_values.append(place_values[-1] * b)
        return jnp.array(place_values)

    @staticmethod
    def _get_bins_fsq(target_codebook_size):
        """
        Get bins per dimension based on codebook size, from the original FSQ paper.
        """
        if target_codebook_size == 2**8:
            return (8, 6, 5)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (8, 5, 5, 5)
        elif target_codebook_size == 2**12:
            return (7, 5, 5, 5, 5)
        elif target_codebook_size == 2**14:
            return (8, 8, 8, 6, 5)
        elif target_codebook_size == 2**16:
            return (8, 8, 8, 5, 5, 5)
        else:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_custom(target_codebook_size):
        if target_codebook_size == 2**8:
            return (16, 16)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (32, 32)
        elif target_codebook_size == 2**12:
            return (64, 64)
        elif target_codebook_size == 2**14:
            return (128, 128)
        elif target_codebook_size == 2**16:
            return (256, 256)
        return None

    @staticmethod
    def _get_bins_lfq(target_codebook_size):
        """
        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension).
        """
        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"

        return (2,) * int(math.log2(target_codebook_size))

    def setup(self):
        self.proj_down = nn.Dense(len(self.bins_per_dim))
        self.proj_up = nn.Dense(self.input_dim)

    def __call__(self, inputs):
        tokens, z = self.encode(inputs)
        output = self.decode(tokens, z_grad=z)
        return tokens, output

    def encode(self, inputs):
        bases = jnp.array(self.bins_per_dim)

        x = self.proj_down(inputs)
        z = jnp.tanh(x)

        # Quantize
        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
        tokens = self.undigitize(digits)

        return tokens, z

    def decode(self, tokens, z_grad: jax.Array | None = None):
        bases = jnp.array(self.bins_per_dim)
        digits = self.digitize(tokens)

        z_q = digits / (bases - 1) * 2 - 1

        if z_grad is not None:
            chex.assert_equal_shape([z_q, z_grad])
            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad

        return self.proj_up(z_q)

    def undigitize(self, digits):
        return jnp.sum(digits * jnp.array(self.place_values), axis=-1)

    def digitize(self, tokens):
        return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)
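
    # Worked example (illustrative): with bins_per_dim == (8, 6, 5), place_values is (1, 8, 48)
    # and the codebook holds 8 * 6 * 5 = 240 codes (the largest FSQ product below the 256-code
    # target). Digits (3, 2, 4) pack via `undigitize` into 3 * 1 + 2 * 8 + 4 * 48 = 211, and
    # `digitize` recovers them as (211 // 1 % 8, 211 // 8 % 6, 211 // 48 % 5) = (3, 2, 4).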

    @property
    def vocab_size(self):
        return math.prod(self.bins_per_dim)


class ResNetDownBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x, *, train=True):
        skip = x

        if self.stride > 1 or x.shape[-1] != self.n_filters:
            skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)

        x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)

        return skip + x


class ResNetUpBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x, *, train=True):
        skip = x

        if self.stride > 1:
            skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)

        x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)

        return skip + x


@dataclass
class LfqCodebookOutput:
    tokens: jnp.ndarray
    z: jnp.ndarray
    z_q: jnp.ndarray
    token_log_probs: jnp.ndarray
    commit_loss: jnp.ndarray


class LookupFreeQuantization(nn.Module):
    num_dims: int
    latent_dim: int

    def setup(self):
        self.codebook = jnp.array([-1, 1])
        # self.activation = lambda x: x
        self.activation = nn.tanh

        self.project_down = nn.Dense(self.num_dims)
        self.project_up = nn.Dense(self.latent_dim)

    def encode(self, z):
        z = self.project_down(z)
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        token_bits = jnp.argmin(token_squared_distances, axis=-1)
        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)

    def decode(self, tokens):
        # token_bits holds 0 or 2**i per dimension; indexing the 2-entry codebook with it relies on
        # JAX clamping out-of-range gather indices, so any nonzero value selects codebook[1] == 1.
        token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
        return self.project_up(self.codebook[token_bits])
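
    # Worked example (illustrative, num_dims == 3): a latent z = (0.3, -0.7, 0.9) is nearest to
    # codebook entries (1, -1, 1), i.e. bits (1, 0, 1), so `encode` packs it into token
    # 1 * 2**0 + 0 * 2**1 + 1 * 2**2 = 5; `decode` recovers the same bit pattern by AND-ing the
    # token with (1, 2, 4) before projecting back up to latent_dim.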

    def loss(self, x):
        z = self.project_down(x)
        z = self.activation(z)

        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        tokens = jnp.argmin(token_squared_distances, axis=-1)

        token_bit_log_probs = -token_squared_distances  # jax.nn.log_softmax(-token_squared_distances, axis=-1)
        # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
        token_bit_expansions = jnp.bitwise_and(
            jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
        ).astype(jnp.int32)
        token_log_probs = (
            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
            + token_bit_log_probs[..., 1] @ token_bit_expansions
        )  # (batch_size, num_tokens, 2 ** num_dims)
        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))

        z_q = self.codebook[tokens]
        commit_loss = jnp.square(z - z_q).mean()
        z_q = jax.lax.stop_gradient(z_q - z) + z

        z_q = self.project_up(z_q)
        z = self.project_up(z)

        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
        return LfqCodebookOutput(
            tokens=tokens,
            z=z,
            z_q=z_q,
            token_log_probs=jnp.zeros(()),
            commit_loss=commit_loss,
        )


def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
    return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))


class GeGLU(Module):
    """Gated Linear Unit with GELU (GeGLU) activation function.

    GeGLU is a Flax layer that combines a linear transformation with a GELU
    activation function in a gating mechanism. It is often used in Transformer models
    to provide non-linear capabilities while preserving a strong linear component.

    Example usage::

        >>> import flax.linen as nn
        >>> class TransformerBlock(nn.Module):
        ...   @nn.compact
        ...   def __call__(self, x):
        ...     x = nn.Dense(2)(x)
        ...     x = GeGLU()(x)
        ...     return x

    Attributes:
        output_dim: the number of output features (default: -1, which keeps the
            input's last dimension).
    """

    output_dim: int = -1

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the GeGLU activation to the inputs.

        Args:
            inputs: the nd-array to apply the GeGLU activation function to.

        Returns:
            The transformed input.
        """
        if self.output_dim == -1:
            output_dim = inputs.shape[-1]
        else:
            output_dim = self.output_dim

        x = nn.Dense(output_dim * 2)(inputs)
        x, gate = x[..., :output_dim], x[..., output_dim:]
        return x * nn.gelu(gate)


class CrossAttentionLayer(nn.Module):
    dropout_rate: float = 0.0
    num_heads: int | None = None
    causal: bool = False
    mlp_ratio: float = 4.0

    @nn.compact
    def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
        d_embed = x.shape[-1]
        seq_len_q = x.shape[-2]
        seq_len_k = y.shape[-2]

        if self.causal:
            # One block size will be 1
            bs_q = max(seq_len_q // seq_len_k, 1)
            bs_k = max(seq_len_k // seq_len_q, 1)

            mask_self = nn.make_causal_mask(x[..., 0])
            mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)

        # Self-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, x, x, mask=mask_self)
        x = skip + x

        # Cross-attention block
        skip = x
        x = nn.LayerNorm()(x)
        # bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
            # attention_fn=partial(nn.dot_product_attention, bias=bias),
        )(x, y, y, mask=mask_cross)
        x = skip + x

        # MLP block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = GeGLU()(x)
        x = nn.Dense(d_embed)(x)
        return skip + x


def sinusoidal_pe_init(_, shape):
    seq_len, d_embed = shape

    position = jnp.arange(0, seq_len, 1)
    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
    return jnp.concatenate(
        [
            jnp.sin(position[:, jnp.newaxis] * div_term),
            jnp.cos(position[:, jnp.newaxis] * div_term),
        ],
        axis=-1,
    )


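# Shape check (illustrative): sinusoidal_pe_init(None, (10, 64)) returns a (10, 64) table whose
# first 32 columns are sin(position * div_term) and last 32 are cos(position * div_term); this is
# the initializer used for the "q_embed" and "y_pos_enc" parameters in TokenizerEncoderDecoder below.

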
class TokenizerEncoderDecoder(nn.Module):
    num_tokens: int
    num_cross_tokens: int
    num_layers: int
    causal: bool

    mlp_ratio: float = 4.0
    use_state_conditioning: bool = False

    @nn.compact
    def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
        x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
        x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])

        if mask is not None:
            # mask is (batch_dims..., num_cross_tokens)
            chex.assert_equal_shape([y[..., 0], mask])
            attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
        else:
            attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens))

        if self.use_state_conditioning:
            assert state_conditioning is not None, "State conditioning is required for this model."
            state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
            y = jnp.concatenate([y, state_embed], axis=-2)
            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)

        y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])

        for _ in range(self.num_layers):
            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
                x, y, train=train, mask_self=None, mask_cross=attn_mask
            )

        return x


class FsqAttentionTokenizer(nn.Module):
    embed_dim: int
    data_dim: int
    data_horizon: int
    num_tokens: int
    num_layers: int
    target_codebook_size: int
    causal: bool = False
    mlp_ratio: float = 2.0

    bound: float | None = None

    use_state_conditioning: bool = False

    @property
    def vocab_size(self):
        return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))

    def setup(self):
        self.proj = nn.Dense(self.embed_dim)
        self.encoder = TokenizerEncoderDecoder(
            num_tokens=self.num_tokens,
            num_cross_tokens=self.data_horizon,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.codebook = FsqCodebook(
            input_dim=self.embed_dim,
            target_codebook_size=self.target_codebook_size,
            codebook_type="custom",
        )
        self.decoder = TokenizerEncoderDecoder(
            num_tokens=self.data_horizon,
            num_cross_tokens=self.num_tokens,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )

        self.proj_mean = nn.Dense(self.data_dim)
        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))

    def tokenize(self, action, *, obs=None, train=False):
        if self.bound is not None:
            action = jnp.clip(action, -self.bound, self.bound)

        x = self.proj(action)
        x = self.encoder(x, train=train, state_conditioning=obs)

        return self.codebook.encode(x)

    def detokenize(self, tokens, *, obs=None):
        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
        mean = self.proj_mean(x)
        return mean * self.out_scale

    def loss(self, action, *, obs=None, train=True):
        # Encode
        x = self.proj(action)
        z = self.encoder(x, train=train, state_conditioning=obs)

        # Quantize
        tokens, z = self.codebook(z)

        # Decode
        x = self.decoder(z, train=train, state_conditioning=obs)
        mean = self.proj_mean(x) * self.out_scale

        mse = jnp.mean(jnp.square(action - mean))
        mae = jnp.mean(jnp.abs(action - mean))

        return mse, {
            "mse": mse,
            "mae": mae,
        }

    def __call__(self, *args, **kwargs):
        """
        Dummy for .init
        """
        return self.loss(*args, **kwargs)
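
# Example usage (a minimal sketch; the hyperparameters and shapes below are illustrative,
# not values prescribed by this module):
#
#   tok = FsqAttentionTokenizer(
#       embed_dim=128, data_dim=8, data_horizon=10,
#       num_tokens=8, num_layers=2, target_codebook_size=2**10,
#   )
#   actions = jnp.zeros((1, 10, 8))  # (batch, data_horizon, data_dim)
#   params = tok.init(jax.random.PRNGKey(0), actions)["params"]  # __call__ dispatches to loss()
#   tokens, _ = tok.apply({"params": params}, actions, method=FsqAttentionTokenizer.tokenize)
#   recon = tok.apply({"params": params}, tokens, method=FsqAttentionTokenizer.detokenize)
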
@@ -1,4 +1,5 @@
import logging
import os

import numpy as np
import sentencepiece
@@ -125,3 +126,215 @@ class FASTTokenizer:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens


class BinningTokenizer:
    def __init__(self, max_len: int = 256, n_bins: int = 256):
        self._max_len = max_len
        self._n_bins = n_bins

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
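
    # Worked example (illustrative): with the [-1, 1] convention above, the bin edges are
    # np.linspace(-1, 1, 257)[:-1], so a normalized state value of -1.0 falls in bin 0,
    # 0.0 falls in bin 128, and anything at or above the last edge falls in bin 255.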

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        if len(action_tokens) < action_horizon * action_dim:
            return np.zeros([action_horizon, action_dim], dtype=np.float32)
        action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
        return action_tokens / self._n_bins * 2 - 1

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens


class FSQTokenizer:
    def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
        import jax
        import orbax.checkpoint as ocp

        import openpi.models.fsq_tokenizer_v2 as fsq_tokenizer

        self._max_len = max_len

        assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
        # Download tokenizer
        path = download.maybe_download(fsq_tokenizer_path)
        tok_path = os.path.join(path, os.listdir(path)[0])

        # Split step from path
        step = int(tok_path.split("/")[-1])
        base_path = tok_path.rsplit("/", 1)[0]

        mgr = ocp.CheckpointManager(
            base_path,
            item_handlers={
                "params": ocp.StandardCheckpointHandler(),
                "opt_state": ocp.StandardCheckpointHandler(),
                "config": ocp.JsonCheckpointHandler(),
            },
            options=ocp.CheckpointManagerOptions(max_to_keep=1),
        )

        try:
            restored = mgr.restore(
                step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
            )
            config = restored["config"]
            self._params = restored["params"]
            self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
            ) from e

        # Compile tokenize and detokenize functions
        self._tokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
        )
        self._detokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
        )

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        try:
            import jax

            # Move computation to CPU and compile on-demand
            device = jax.devices("cpu")[0]
            with jax.default_device(device):
                detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
                return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
        except Exception as e:
            logging.warning(f"Error decoding FSQ: {e}")
            return np.zeros((action_horizon, action_dim))

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens

@@ -102,6 +102,8 @@ class ModelTransformFactory(GroupFactory):

    # If provided, will determine the default prompt that will be used by the model.
    default_prompt: str | None = None
    fast_model_tokenizer: Any | None = None
    fast_model_tokenizer_kwargs: dict[str, Any] | None = None

    def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
        match model_config.model_type:
@@ -116,17 +118,21 @@ class ModelTransformFactory(GroupFactory):
                    ],
                )
            case _model.ModelType.PI0_FAST:
+               tokenizer_cls = (
+                   _tokenizer.FASTTokenizer if self.fast_model_tokenizer is None else self.fast_model_tokenizer
+               )
+               tokenizer_kwargs = {} if self.fast_model_tokenizer_kwargs is None else self.fast_model_tokenizer_kwargs
                return _transforms.Group(
                    inputs=[
                        _transforms.InjectDefaultPrompt(self.default_prompt),
                        _transforms.ResizeImages(224, 224),
                        _transforms.TokenizeFASTInputs(
-                           _tokenizer.FASTTokenizer(model_config.max_token_len),
+                           tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                        ),
                    ],
                    outputs=[
                        _transforms.ExtractFASTActions(
-                           _tokenizer.FASTTokenizer(model_config.max_token_len),
+                           tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                            action_horizon=model_config.action_horizon,
                            action_dim=model_config.action_dim,
                        )
@@ -470,6 +476,78 @@ _CONFIGS = [
            ),
        ),
    ),
    TrainConfig(
        name="pi0_fast_droid_jointpos",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[
                    _transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
                    droid_policy.DroidOutputs(),
                ],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_binning_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
            model_transforms=ModelTransformFactory(
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_fast_specialist_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
            model_transforms=ModelTransformFactory(
                fast_model_tokenizer=_tokenizer.FASTTokenizer,
                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_vq_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
            model_transforms=ModelTransformFactory(
                fast_model_tokenizer=_tokenizer.FSQTokenizer,
                fast_model_tokenizer_kwargs={
                    "fsq_tokenizer_path": "s3://openpi-assets-simeval/tokenizers/droid_fsq_tokenizer"
                },
            ),
        ),
    ),
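    # Example (a sketch, assuming the usual config-registry entry point): the DROID configs above
    # are looked up by name, e.g.
    #   from openpi.training import config as _config
    #   cfg = _config.get_config("paligemma_vq_droid")
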
    #
    # Fine-tuning Libero configs.
    #