add DROID policies

src/openpi/models/fsq_tokenizer_v2.py (new file, 466 lines added)
@@ -0,0 +1,466 @@
import math
from typing import Literal

import chex
from einops import einops
from flax import linen as nn
from flax.linen.module import Module
from flax.linen.module import compact
from flax.struct import dataclass
from flax.typing import Array
import jax
import jax.numpy as jnp

class FsqCodebook(nn.Module):
    input_dim: int
    target_codebook_size: int
    codebook_type: Literal["fsq", "lfq", "custom"]

    _bins_per_dim: tuple[int, ...] | None = None

    @property
    def bins_per_dim(self):
        if self._bins_per_dim is not None:
            return self._bins_per_dim

        if self.codebook_type == "fsq":
            return self._get_bins_fsq(self.target_codebook_size)
        elif self.codebook_type == "lfq":  # noqa: RET505
            return self._get_bins_lfq(self.target_codebook_size)
        elif self.codebook_type == "custom":
            return self._get_bins_custom(self.target_codebook_size)
        else:
            raise ValueError(f"Codebook type {self.codebook_type} not supported.")

    @property
    def place_values(self):
        place_values = [1]
        for b in self.bins_per_dim[:-1]:
            place_values.append(place_values[-1] * b)
        return jnp.array(place_values)

    @staticmethod
    def _get_bins_fsq(target_codebook_size):
        """
        Get bins per dimension based on codebook size, from the original FSQ paper.
        """
        if target_codebook_size == 2**8:
            return (8, 6, 5)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (8, 5, 5, 5)
        elif target_codebook_size == 2**12:
            return (7, 5, 5, 5, 5)
        elif target_codebook_size == 2**14:
            return (8, 8, 8, 6, 5)
        elif target_codebook_size == 2**16:
            return (8, 8, 8, 5, 5, 5)
        else:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_custom(target_codebook_size):
        if target_codebook_size == 2**8:
            return (16, 16)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (32, 32)
        elif target_codebook_size == 2**12:
            return (64, 64)
        elif target_codebook_size == 2**14:
            return (128, 128)
        elif target_codebook_size == 2**16:
            return (256, 256)
        raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_lfq(target_codebook_size):
        """
        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension).
        """
        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"

        return (2,) * int(math.log2(target_codebook_size))

    def setup(self):
        self.proj_down = nn.Dense(len(self.bins_per_dim))
        self.proj_up = nn.Dense(self.input_dim)

    def __call__(self, inputs):
        tokens, z = self.encode(inputs)
        output = self.decode(tokens, z_grad=z)
        return tokens, output

    def encode(self, inputs):
        bases = jnp.array(self.bins_per_dim)

        x = self.proj_down(inputs)
        z = jnp.tanh(x)

        # Quantize
        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
        tokens = self.undigitize(digits)

        return tokens, z

    def decode(self, tokens, z_grad: jax.Array | None = None):
        bases = jnp.array(self.bins_per_dim)
        digits = self.digitize(tokens)

        z_q = digits / (bases - 1) * 2 - 1

        if z_grad is not None:
            chex.assert_equal_shape([z_q, z_grad])
            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad

        return self.proj_up(z_q)

    def undigitize(self, digits):
        return jnp.sum(digits * jnp.array(self.place_values), axis=-1)

    def digitize(self, tokens):
        return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)

    @property
    def vocab_size(self):
        return math.prod(self.bins_per_dim)

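(Aside, not part of the commit: a minimal sketch of how FsqCodebook might be exercised end to end. All shapes, sizes, and variable names below are assumptions for illustration.)

# Illustrative only -- not part of this file. Sizes and names are assumptions.
import jax
import jax.numpy as jnp

codebook = FsqCodebook(input_dim=64, target_codebook_size=2**10, codebook_type="fsq")
latents = jnp.zeros((2, 16, 64))  # hypothetical (batch, tokens, input_dim) features
variables = codebook.init(jax.random.PRNGKey(0), latents)

# __call__ quantizes and decodes with a straight-through gradient path.
tokens, recon = codebook.apply(variables, latents)
# encode/decode can also be called separately; token ids lie in [0, vocab_size).
ids, z = codebook.apply(variables, latents, method=codebook.encode)
decoded = codebook.apply(variables, ids, method=codebook.decode)
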
class ResNetDownBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x, *, train=True):
        skip = x

        if self.stride > 1 or x.shape[-1] != self.n_filters:
            skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)

        x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)

        return skip + x


class ResNetUpBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x, *, train=True):
        skip = x

        if self.stride > 1:
            skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)

        x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)

        return skip + x

@dataclass
class LfqCodebookOutput:
    tokens: jnp.ndarray
    z: jnp.ndarray
    z_q: jnp.ndarray
    token_log_probs: jnp.ndarray
    commit_loss: jnp.ndarray


class LookupFreeQuantization(nn.Module):
    num_dims: int
    latent_dim: int

    def setup(self):
        self.codebook = jnp.array([-1, 1])
        # self.activation = lambda x: x
        self.activation = nn.tanh

        self.project_down = nn.Dense(self.num_dims)
        self.project_up = nn.Dense(self.latent_dim)

    def encode(self, z):
        z = self.project_down(z)
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        token_bits = jnp.argmin(token_squared_distances, axis=-1)
        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)

    def decode(self, tokens):
        token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
        return self.project_up(self.codebook[token_bits])

    def loss(self, x):
        z = self.project_down(x)
        z = self.activation(z)

        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        tokens = jnp.argmin(token_squared_distances, axis=-1)

        token_bit_log_probs = -token_squared_distances  # jax.nn.log_softmax(-token_squared_distances, axis=-1)
        # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
        token_bit_expansions = jnp.bitwise_and(
            jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
        ).astype(jnp.int32)
        token_log_probs = (
            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
            + token_bit_log_probs[..., 1] @ token_bit_expansions
        )  # (batch_size, num_tokens, 2 ** num_dims)
        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))

        z_q = self.codebook[tokens]
        commit_loss = jnp.square(z - z_q).mean()
        z_q = jax.lax.stop_gradient(z_q - z) + z

        z_q = self.project_up(z_q)
        z = self.project_up(z)

        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
        return LfqCodebookOutput(
            tokens=tokens,
            z=z,
            z_q=z_q,
            token_log_probs=jnp.zeros(()),
            commit_loss=commit_loss,
        )

def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
    return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))

class GeGLU(Module):
    """Gated Linear Unit with GELU (GeGLU) activation function.

    GeGLU is a Flax layer that combines a linear transformation with a GELU
    activation function in a gating mechanism. It is often used in Transformer models
    to provide non-linear capabilities while preserving a strong linear component.

    Example usage::

        >>> import flax.linen as nn
        >>> class TransformerBlock(nn.Module):
        ...     @nn.compact
        ...     def __call__(self, x):
        ...         x = nn.Dense(2)(x)
        ...         x = GeGLU()(x)  # initialized
        ...         return x

    Attributes:
        output_dim: the number of output features (default: -1, which matches the input dimension).
    """

    output_dim: int = -1

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the GeGLU activation to the inputs.

        Args:
            inputs: the nd-array to apply the GeGLU activation function to.

        Returns:
            The transformed input.
        """
        if self.output_dim == -1:
            output_dim = inputs.shape[-1]
        else:
            output_dim = self.output_dim

        x = nn.Dense(output_dim * 2)(inputs)
        x, gate = x[..., :output_dim], x[..., output_dim:]
        return x * nn.gelu(gate)

class CrossAttentionLayer(nn.Module):
    dropout_rate: float = 0.0
    num_heads: int | None = None
    causal: bool = False
    mlp_ratio: float = 4.0

    @nn.compact
    def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
        d_embed = x.shape[-1]
        seq_len_q = x.shape[-2]
        seq_len_k = y.shape[-2]

        if self.causal:
            # One block size will be 1
            bs_q = max(seq_len_q // seq_len_k, 1)
            bs_k = max(seq_len_k // seq_len_q, 1)

            mask_self = nn.make_causal_mask(x[..., 0])
            mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)

        # Self-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, x, x, mask=mask_self)
        x = skip + x

        # Cross-attention block
        skip = x
        x = nn.LayerNorm()(x)
        # bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
            # attention_fn=partial(nn.dot_product_attention, bias=bias),
        )(x, y, y, mask=mask_cross)
        x = skip + x

        # MLP block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = GeGLU()(x)
        x = nn.Dense(d_embed)(x)
        return skip + x

def sinusoidal_pe_init(_, shape):
    seq_len, d_embed = shape

    position = jnp.arange(0, seq_len, 1)
    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
    return jnp.concatenate(
        [
            jnp.sin(position[:, jnp.newaxis] * div_term),
            jnp.cos(position[:, jnp.newaxis] * div_term),
        ],
        axis=-1,
    )

class TokenizerEncoderDecoder(nn.Module):
    num_tokens: int
    num_cross_tokens: int
    num_layers: int
    causal: bool

    mlp_ratio: float = 4.0
    use_state_conditioning: bool = False

    @nn.compact
    def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
        x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
        x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])

        if mask is not None:
            # mask is (batch_dims..., num_cross_tokens)
            chex.assert_equal_shape([y[..., 0], mask])
            attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
        else:
            attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens))

        if self.use_state_conditioning:
            assert state_conditioning is not None, "State conditioning is required for this model."
            state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
            y = jnp.concatenate([y, state_embed], axis=-2)
            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)

        y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])

        for _ in range(self.num_layers):
            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
                x, y, train=train, mask_self=None, mask_cross=attn_mask
            )

        return x

class FsqAttentionTokenizer(nn.Module):
    embed_dim: int
    data_dim: int
    data_horizon: int
    num_tokens: int
    num_layers: int
    target_codebook_size: int
    causal: bool = False
    mlp_ratio: float = 2.0

    bound: float | None = None

    use_state_conditioning: bool = False

    @property
    def vocab_size(self):
        return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))

    def setup(self):
        self.proj = nn.Dense(self.embed_dim)
        self.encoder = TokenizerEncoderDecoder(
            num_tokens=self.num_tokens,
            num_cross_tokens=self.data_horizon,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.codebook = FsqCodebook(
            input_dim=self.embed_dim,
            target_codebook_size=self.target_codebook_size,
            codebook_type="custom",
        )
        self.decoder = TokenizerEncoderDecoder(
            num_tokens=self.data_horizon,
            num_cross_tokens=self.num_tokens,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )

        self.proj_mean = nn.Dense(self.data_dim)
        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))

    def tokenize(self, action, *, obs=None, train=False):
        if self.bound is not None:
            action = jnp.clip(action, -self.bound, self.bound)

        x = self.proj(action)
        x = self.encoder(x, train=train, state_conditioning=obs)

        return self.codebook.encode(x)

    def detokenize(self, tokens, *, obs=None):
        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
        mean = self.proj_mean(x)
        return mean * self.out_scale

    def loss(self, action, *, obs=None, train=True):
        # Encode
        x = self.proj(action)
        z = self.encoder(x, train=train, state_conditioning=obs)

        # Quantize
        tokens, z = self.codebook(z)

        # Decode
        x = self.decoder(z, train=train, state_conditioning=obs)
        mean = self.proj_mean(x) * self.out_scale

        mse = jnp.mean(jnp.square(action - mean))
        mae = jnp.mean(jnp.abs(action - mean))

        return mse, {
            "mse": mse,
            "mae": mae,
        }

    def __call__(self, *args, **kwargs):
        """
        Dummy for .init
        """
        return self.loss(*args, **kwargs)

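(Aside, not part of the commit: roughly how the full attention tokenizer can be initialized and round-tripped. The hyperparameters below are assumptions, not the trained DROID settings; the real values come from the checkpoint config that FSQTokenizer below restores via orbax.)

# Illustrative only -- hyperparameters are assumptions.
import jax
import jax.numpy as jnp

tok = FsqAttentionTokenizer(
    embed_dim=128,
    data_dim=8,        # e.g. action dimension
    data_horizon=10,   # e.g. action chunk length
    num_tokens=16,
    num_layers=2,
    target_codebook_size=2**10,
)
actions = jnp.zeros((1, 10, 8))  # (batch, data_horizon, data_dim)
variables = tok.init(jax.random.PRNGKey(0), actions, train=False)  # __call__ dispatches to loss()

# tokenize returns (token ids of shape (1, num_tokens), pre-quantization latents)
tokens, _ = tok.apply(variables, actions, method=tok.tokenize)
recon = tok.apply(variables, tokens, method=tok.detokenize)  # back to (1, data_horizon, data_dim)
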
@@ -1,4 +1,5 @@
 import logging
+import os

 import numpy as np
 import sentencepiece
@@ -125,3 +126,215 @@ class FASTTokenizer:
         if isinstance(tokens, list):
             tokens = np.array(tokens)
         return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
+
+
+class BinningTokenizer:
+    def __init__(self, max_len: int = 256, n_bins: int = 256):
+        self._max_len = max_len
+        self._n_bins = n_bins
+
+        # Download base PaliGemma tokenizer
+        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
+        with path.open("rb") as f:
+            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
+
+        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+    def tokenize(
+        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        cleaned_text = prompt.lower().strip().replace("_", " ")
+
+        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+        # Convention: prefix includes prompt and string-representation of state, followed by ';'
+        state_str = " ".join(map(str, discretized_state))
+        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
+        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+        if actions is not None:
+            raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
+        postfix_tokens = []
+
+        # Create output token sequence & masks
+        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+        tokens = prefix_tokens + postfix_tokens
+        token_mask = [True] * len(tokens)
+        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only
+
+        # Pad tokens to max length
+        tokens_len = len(tokens)
+        if tokens_len < self._max_len:
+            padding = [False] * (self._max_len - tokens_len)
+            tokens = tokens + padding
+            token_mask = token_mask + padding
+            ar_mask = ar_mask + padding
+            loss_mask = loss_mask + padding
+        else:
+            if len(tokens) > self._max_len:
+                logging.warning(
+                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
+                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
+                )
+            tokens = tokens[: self._max_len]
+            token_mask = token_mask[: self._max_len]
+            ar_mask = ar_mask[: self._max_len]
+            loss_mask = loss_mask[: self._max_len]
+
+        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
+
+    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
+        # Decode predicted output tokens
+        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+        # Extract actions from FAST model outputs
+        if "Action: " not in decoded_tokens:
+            return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+        # Extract actions from decoded tokens
+        raw_action_tokens = np.array(
+            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
+        )
+        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+        if len(action_tokens) < action_horizon * action_dim:
+            return np.zeros([action_horizon, action_dim], dtype=np.float32)
+        action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
+        return action_tokens / self._n_bins * 2 - 1
+
+    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
+        if isinstance(tokens, list):
+            tokens = np.array(tokens)
+        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
+
+
+class FSQTokenizer:
+    def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
+        import jax
+        import orbax.checkpoint as ocp
+
+        import openpi.models.fsq_tokenizer_v2 as fsq_tokenizer
+
+        self._max_len = max_len
+
+        assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
+        # Download tokenizer
+        path = download.maybe_download(fsq_tokenizer_path)
+        tok_path = os.path.join(path, os.listdir(path)[0])
+
+        # Split step from path
+        step = int(tok_path.split("/")[-1])
+        base_path = tok_path.rsplit("/", 1)[0]
+
+        mgr = ocp.CheckpointManager(
+            base_path,
+            item_handlers={
+                "params": ocp.StandardCheckpointHandler(),
+                "opt_state": ocp.StandardCheckpointHandler(),
+                "config": ocp.JsonCheckpointHandler(),
+            },
+            options=ocp.CheckpointManagerOptions(max_to_keep=1),
+        )
+
+        try:
+            restored = mgr.restore(
+                step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
+            )
+            config = restored["config"]
+            self._params = restored["params"]
+            self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
+            ) from e
+
+        # Compile tokenize and detokenize functions
+        self._tokenize_fn = jax.jit(
+            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
+        )
+        self._detokenize_fn = jax.jit(
+            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
+        )
+
+        # Download base PaliGemma tokenizer
+        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
+        with path.open("rb") as f:
+            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
+
+        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+    def tokenize(
+        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        cleaned_text = prompt.lower().strip().replace("_", " ")
+
+        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+        # Convention: prefix includes prompt and string-representation of state, followed by ';'
+        state_str = " ".join(map(str, discretized_state))
+        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
+        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+        if actions is not None:
+            raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
+        postfix_tokens = []
+
+        # Create output token sequence & masks
+        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+        tokens = prefix_tokens + postfix_tokens
+        token_mask = [True] * len(tokens)
+        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only
+
+        # Pad tokens to max length
+        tokens_len = len(tokens)
+        if tokens_len < self._max_len:
+            padding = [False] * (self._max_len - tokens_len)
+            tokens = tokens + padding
+            token_mask = token_mask + padding
+            ar_mask = ar_mask + padding
+            loss_mask = loss_mask + padding
+        else:
+            if len(tokens) > self._max_len:
+                logging.warning(
+                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
+                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
+                )
+            tokens = tokens[: self._max_len]
+            token_mask = token_mask[: self._max_len]
+            ar_mask = ar_mask[: self._max_len]
+            loss_mask = loss_mask[: self._max_len]
+
+        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
+
+    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
+        # Decode predicted output tokens
+        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+        # Extract actions from FAST model outputs
+        if "Action: " not in decoded_tokens:
+            return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+        # Extract actions from decoded tokens
+        raw_action_tokens = np.array(
+            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
+        )
+        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+        try:
+            import jax
+
+            # Move computation to CPU and compile on-demand
+            device = jax.devices("cpu")[0]
+            with jax.default_device(device):
+                detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
+            return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
+        except Exception as e:
+            logging.warning(f"Error decoding FSQ: {e}")
+            return np.zeros((action_horizon, action_dim))
+
+    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
+        if isinstance(tokens, list):
+            tokens = np.array(tokens)
+        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
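(Usage note, not part of the diff: both new wrappers mirror the FASTTokenizer interface, i.e. tokenize / extract_actions, so they can be constructed directly. A minimal sketch with a made-up prompt and state; the checkpoint path is the one referenced by the paligemma_vq_droid config further down.)

# Illustrative only; prompt and state are hypothetical, the path matches the config below.
import numpy as np

fsq_tok = FSQTokenizer(max_len=256, fsq_tokenizer_path="s3://openpi-assets-simeval/tokenizers/droid_fsq_tokenizer")
state = np.zeros(8, dtype=np.float32)  # hypothetical normalized robot state
tokens, token_mask, ar_mask, loss_mask = fsq_tok.tokenize("pick up the cup", state, None)
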
@@ -102,6 +102,8 @@ class ModelTransformFactory(GroupFactory):

     # If provided, will determine the default prompt that be used by the model.
     default_prompt: str | None = None
+    fast_model_tokenizer: Any | None = None
+    fast_model_tokenizer_kwargs: dict[str, Any] | None = None

     def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
         match model_config.model_type:
@@ -116,17 +118,21 @@ class ModelTransformFactory(GroupFactory):
                     ],
                 )
             case _model.ModelType.PI0_FAST:
+                tokenizer_cls = (
+                    _tokenizer.FASTTokenizer if self.fast_model_tokenizer is None else self.fast_model_tokenizer
+                )
+                tokenizer_kwargs = {} if self.fast_model_tokenizer_kwargs is None else self.fast_model_tokenizer_kwargs
                 return _transforms.Group(
                     inputs=[
                         _transforms.InjectDefaultPrompt(self.default_prompt),
                         _transforms.ResizeImages(224, 224),
                         _transforms.TokenizeFASTInputs(
-                            _tokenizer.FASTTokenizer(model_config.max_token_len),
+                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                         ),
                     ],
                     outputs=[
                         _transforms.ExtractFASTActions(
-                            _tokenizer.FASTTokenizer(model_config.max_token_len),
+                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                             action_horizon=model_config.action_horizon,
                             action_dim=model_config.action_dim,
                         )
@@ -470,6 +476,78 @@ _CONFIGS = [
             ),
         ),
     ),
+    TrainConfig(
+        name="pi0_fast_droid_jointpos",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[
+                    _transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
+                    droid_policy.DroidOutputs(),
+                ],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+        ),
+    ),
+    TrainConfig(
+        name="paligemma_binning_droid",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[droid_policy.DroidOutputs()],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+            model_transforms=ModelTransformFactory(
+                fast_model_tokenizer=_tokenizer.BinningTokenizer,
+            ),
+        ),
+    ),
+    TrainConfig(
+        name="paligemma_fast_specialist_droid",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[droid_policy.DroidOutputs()],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+            model_transforms=ModelTransformFactory(
+                fast_model_tokenizer=_tokenizer.FASTTokenizer,
+                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
+            ),
+        ),
+    ),
+    TrainConfig(
+        name="paligemma_vq_droid",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[droid_policy.DroidOutputs()],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+            model_transforms=ModelTransformFactory(
+                fast_model_tokenizer=_tokenizer.FSQTokenizer,
+                fast_model_tokenizer_kwargs={
+                    "fsq_tokenizer_path": "s3://openpi-assets-simeval/tokenizers/droid_fsq_tokenizer"
+                },
+            ),
+        ),
+    ),
 #
 # Fine-tuning Libero configs.
 #
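(A final hedged pointer, not part of the diff: the new entries are plain additions to _CONFIGS, so they should be retrievable by name through the existing config lookup; get_config below is assumed to be the repository's standard helper rather than something introduced by this commit.)

# Illustrative only -- assumes the existing openpi.training.config.get_config helper.
from openpi.training import config as _config

cfg = _config.get_config("paligemma_vq_droid")
print(cfg.model.action_dim, cfg.model.action_horizon)  # expected: 8 15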