From 1ce9ffe1349db6d259f1fa374eddcfdf7fc91bca Mon Sep 17 00:00:00 2001
From: Karl Pertsch
Date: Mon, 14 Apr 2025 18:42:57 +0000
Subject: [PATCH] add DROID policies

---
 src/openpi/models/fsq_tokenizer_v2.py | 466 ++++++++++++++++++++++++++
 src/openpi/models/tokenizer.py        | 213 ++++++++++++
 src/openpi/training/config.py         |  82 ++++-
 3 files changed, 759 insertions(+), 2 deletions(-)
 create mode 100644 src/openpi/models/fsq_tokenizer_v2.py

diff --git a/src/openpi/models/fsq_tokenizer_v2.py b/src/openpi/models/fsq_tokenizer_v2.py
new file mode 100644
index 0000000..e032c53
--- /dev/null
+++ b/src/openpi/models/fsq_tokenizer_v2.py
@@ -0,0 +1,466 @@
+import math
+from typing import Literal
+
+import chex
+import einops
+from flax import linen as nn
+from flax.linen.module import Module
+from flax.linen.module import compact
+from flax.struct import dataclass
+from flax.typing import Array
+import jax
+import jax.numpy as jnp
+
+
+class FsqCodebook(nn.Module):
+    input_dim: int
+    target_codebook_size: int
+    codebook_type: Literal["fsq", "lfq", "custom"]
+
+    _bins_per_dim: tuple[int, ...] | None = None
+
+    @property
+    def bins_per_dim(self):
+        if self._bins_per_dim is not None:
+            return self._bins_per_dim
+
+        if self.codebook_type == "fsq":
+            return self._get_bins_fsq(self.target_codebook_size)
+        elif self.codebook_type == "lfq":  # noqa: RET505
+            return self._get_bins_lfq(self.target_codebook_size)
+        elif self.codebook_type == "custom":
+            return self._get_bins_custom(self.target_codebook_size)
+        else:
+            raise ValueError(f"Codebook type {self.codebook_type} not supported.")
+
+    @property
+    def place_values(self):
+        place_values = [1]
+        for b in self.bins_per_dim[:-1]:
+            place_values.append(place_values[-1] * b)
+        return jnp.array(place_values)
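+
+    # Worked example (illustration only): with bins_per_dim == (8, 6, 5),
+    # place_values == (1, 8, 48), and per-dimension digits pack into a single
+    # integer token via a mixed-radix sum:
+    #
+    #   >>> digits = jnp.array([3, 2, 4])
+    #   >>> jnp.sum(digits * jnp.array([1, 8, 48]))                 # 3 + 16 + 192 = 211
+    #   >>> (211 // jnp.array([1, 8, 48])) % jnp.array([8, 6, 5])   # -> [3, 2, 4]
+    #
+    # `undigitize` and `digitize` below implement exactly this round trip.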
+    @staticmethod
+    def _get_bins_fsq(target_codebook_size):
+        """
+        Get bins per dimension based on codebook size, from the original FSQ paper.
+        """
+        if target_codebook_size == 2**8:
+            return (8, 6, 5)
+        elif target_codebook_size == 2**10:  # noqa: RET505
+            return (8, 5, 5, 5)
+        elif target_codebook_size == 2**12:
+            return (7, 5, 5, 5, 5)
+        elif target_codebook_size == 2**14:
+            return (8, 8, 8, 6, 5)
+        elif target_codebook_size == 2**16:
+            return (8, 8, 8, 5, 5, 5)
+        else:
+            raise ValueError(f"Codebook size {target_codebook_size} not supported.")
+
+    @staticmethod
+    def _get_bins_custom(target_codebook_size):
+        if target_codebook_size == 2**8:
+            return (16, 16)
+        elif target_codebook_size == 2**10:  # noqa: RET505
+            return (32, 32)
+        elif target_codebook_size == 2**12:
+            return (64, 64)
+        elif target_codebook_size == 2**14:
+            return (128, 128)
+        elif target_codebook_size == 2**16:
+            return (256, 256)
+        return None
+
+    @staticmethod
+    def _get_bins_lfq(target_codebook_size):
+        """
+        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension).
+        """
+        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"
+
+        return (2,) * int(math.log2(target_codebook_size))
+
+    def setup(self):
+        self.proj_down = nn.Dense(len(self.bins_per_dim))
+        self.proj_up = nn.Dense(self.input_dim)
+
+    def __call__(self, inputs):
+        tokens, z = self.encode(inputs)
+        output = self.decode(tokens, z_grad=z)
+        return tokens, output
+
+    def encode(self, inputs):
+        bases = jnp.array(self.bins_per_dim)
+
+        x = self.proj_down(inputs)
+        z = jnp.tanh(x)
+
+        # Quantize
+        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
+        tokens = self.undigitize(digits)
+
+        return tokens, z
+
+    def decode(self, tokens, z_grad: jax.Array | None = None):
+        bases = jnp.array(self.bins_per_dim)
+        digits = self.digitize(tokens)
+
+        z_q = digits / (bases - 1) * 2 - 1
+
+        if z_grad is not None:
+            chex.assert_equal_shape([z_q, z_grad])
+            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
+
+        return self.proj_up(z_q)
+
+    def undigitize(self, digits):
+        return jnp.sum(digits * jnp.array(self.place_values), axis=-1)
+
+    def digitize(self, tokens):
+        return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)
+
+    @property
+    def vocab_size(self):
+        return math.prod(self.bins_per_dim)
+
+
+class ResNetDownBlock(nn.Module):
+    stride: int = 1
+    n_filters: int = 64
+    dropout_rate: float = 0.0
+    group_size: int = 32
+
+    @nn.compact
+    def __call__(self, x, *, train=True):
+        skip = x
+
+        if self.stride > 1 or x.shape[-1] != self.n_filters:
+            skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
+
+        x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
+        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
+        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+        x = nn.relu(x)
+        x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)
+
+        return skip + x
+
+
+class ResNetUpBlock(nn.Module):
+    stride: int = 1
+    n_filters: int = 64
+    dropout_rate: float = 0.0
+    group_size: int = 32
+
+    @nn.compact
+    def __call__(self, x, *, train=True):
+        skip = x
+
+        if self.stride > 1:
+            skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
+
+        x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
+        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
+        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+        x = nn.relu(x)
+        x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)
+
+        return skip + x
+
+
+@dataclass
+class LfqCodebookOutput:
+    tokens: jnp.ndarray
+    z: jnp.ndarray
+    z_q: jnp.ndarray
+    token_log_probs: jnp.ndarray
+    commit_loss: jnp.ndarray
+
+
+class LookupFreeQuantization(nn.Module):
+    num_dims: int
+    latent_dim: int
+
+    def setup(self):
+        self.codebook = jnp.array([-1, 1])
+        # self.activation = lambda x: x
+        self.activation = nn.tanh
+
+        self.project_down = nn.Dense(self.num_dims)
+        self.project_up = nn.Dense(self.latent_dim)
+
+    def encode(self, z):
+        z = self.project_down(z)
+        token_squared_distances = jnp.square(z[..., None] - self.codebook)
+        token_bits = jnp.argmin(token_squared_distances, axis=-1)
+        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)
+
+    def decode(self, tokens):
+        # Extract explicit 0/1 bits before indexing the two-entry codebook
+        # (avoids relying on JAX's clamped out-of-bounds indexing).
+        token_bits = ((tokens[..., None] & (2 ** jnp.arange(self.num_dims))) > 0).astype(jnp.int32)
+        return self.project_up(self.codebook[token_bits])
+
+    def loss(self, x):
+        z = self.project_down(x)
+        z = self.activation(z)
+
+        token_squared_distances = jnp.square(z[..., None] - self.codebook)
+        tokens = jnp.argmin(token_squared_distances, axis=-1)
+
+        token_bit_log_probs = -token_squared_distances  # jax.nn.log_softmax(-token_squared_distances, axis=-1)
+        # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs.
+        # The expansion must be a 0/1 indicator, not the raw bit mask values.
+        token_bit_expansions = (
+            jnp.bitwise_and(jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]) > 0
+        ).astype(jnp.int32)
+        token_log_probs = (
+            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
+            + token_bit_log_probs[..., 1] @ token_bit_expansions
+        )  # (batch_size, num_tokens, 2 ** num_dims)
+        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
+        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
+
+        z_q = self.codebook[tokens]
+        commit_loss = jnp.square(z - z_q).mean()
+        z_q = jax.lax.stop_gradient(z_q - z) + z
+
+        z_q = self.project_up(z_q)
+        z = self.project_up(z)
+
+        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
+        return LfqCodebookOutput(
+            tokens=tokens,
+            z=z,
+            z_q=z_q,
+            token_log_probs=jnp.zeros(()),
+            commit_loss=commit_loss,
+        )
+
+
+def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
+    return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
+
+
+class GeGLU(Module):
+    """Gated Linear Unit with GELU (GeGLU) activation function.
+
+    GeGLU is a Flax layer that combines a linear transformation with a GELU
+    activation function in a gating mechanism. It is often used in Transformer
+    models to provide non-linear capabilities while preserving a strong linear
+    component.
+
+    Example usage::
+
+        >>> class TransformerBlock(nn.Module):
+        ...     @nn.compact
+        ...     def __call__(self, x):
+        ...         x = nn.Dense(2)(x)
+        ...         x = GeGLU()(x)
+        ...         return x
+
+    Attributes:
+        output_dim: the number of output features (default: -1, i.e. same as the input dimension).
+    """
+
+    output_dim: int = -1
+
+    @compact
+    def __call__(self, inputs: Array) -> Array:
+        """Applies the GeGLU activation to the inputs.
+
+        Args:
+            inputs: the nd-array to apply the GeGLU activation function to.
+
+        Returns:
+            The transformed input.
+        """
+        if self.output_dim == -1:
+            output_dim = inputs.shape[-1]
+        else:
+            output_dim = self.output_dim
+
+        x = nn.Dense(output_dim * 2)(inputs)
+        x, gate = x[..., :output_dim], x[..., output_dim:]
+        return x * nn.gelu(gate)
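+
+
+# Shape sketch for GeGLU (illustration only): the input is projected to
+# 2 * output_dim features, split into a value half and a gate half, and
+# recombined as value * gelu(gate):
+#
+#   >>> layer = GeGLU(output_dim=128)
+#   >>> params = layer.init(jax.random.PRNGKey(0), jnp.ones((2, 16, 512)))
+#   >>> layer.apply(params, jnp.ones((2, 16, 512))).shape
+#   (2, 16, 128)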
+ """ + if self.output_dim == -1: + output_dim = inputs.shape[-1] + else: + output_dim = self.output_dim + + x = nn.Dense(output_dim * 2)(inputs) + x, gate = x[..., :output_dim], x[..., output_dim:] + return x * nn.gelu(gate) + + +class CrossAttentionLayer(nn.Module): + dropout_rate: float = 0.0 + num_heads: int = None + causal: bool = False + mlp_ratio: float = 4.0 + + @nn.compact + def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True): + d_embed = x.shape[-1] + seq_len_q = x.shape[-2] + seq_len_k = y.shape[-2] + + if self.causal: + # One block size will be 1 + bs_q = max(seq_len_q // seq_len_k, 1) + bs_k = max(seq_len_k // seq_len_q, 1) + + mask_self = nn.make_causal_mask(x[..., 0]) + mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k) + + # Self-attention block + skip = x + x = nn.LayerNorm()(x) + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads or d_embed // 64, + dropout_rate=self.dropout_rate, + deterministic=not train, + )(x, x, x, mask=mask_self) + x = skip + x + + # Cross-attention block + skip = x + x = nn.LayerNorm()(x) + # bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5 + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads or d_embed // 64, + dropout_rate=self.dropout_rate, + deterministic=not train, + # attention_fn=partial(nn.dot_product_attention, bias=bias), + )(x, y, y, mask=mask_cross) + x = skip + x + + # MLP block + skip = x + x = nn.LayerNorm()(x) + x = nn.Dense(int(d_embed * self.mlp_ratio))(x) + x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) + x = GeGLU()(x) + x = nn.Dense(d_embed)(x) + return skip + x + + +def sinusoidal_pe_init(_, shape): + seq_len, d_embed = shape + + position = jnp.arange(0, seq_len, 1) + div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed)) + return jnp.concatenate( + [ + jnp.sin(position[:, jnp.newaxis] * div_term), + jnp.cos(position[:, jnp.newaxis] * div_term), + ], + axis=-1, + ) + + +class TokenizerEncoderDecoder(nn.Module): + num_tokens: int + num_cross_tokens: int + num_layers: int + causal: bool + + mlp_ratio: float = 4.0 + use_state_conditioning: bool = False + + @nn.compact + def __call__(self, y, *, train=True, state_conditioning=None, mask=None): + x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1])) + x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:]) + + if mask is not None: + # mask is (batch_dims..., num_cross_tokens) + chex.assert_equal_shape([y[..., 0], mask]) + attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens) + else: + attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens)) + + if self.use_state_conditioning: + assert state_conditioning is not None, "State conditioning is required for this model." 
+
+
+class TokenizerEncoderDecoder(nn.Module):
+    num_tokens: int
+    num_cross_tokens: int
+    num_layers: int
+    causal: bool
+
+    mlp_ratio: float = 4.0
+    use_state_conditioning: bool = False
+
+    @nn.compact
+    def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
+        x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
+        x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
+
+        if mask is not None:
+            # mask is (batch_dims..., num_cross_tokens)
+            chex.assert_equal_shape([y[..., 0], mask])
+            attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
+        else:
+            attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens))
+
+        if self.use_state_conditioning:
+            assert state_conditioning is not None, "State conditioning is required for this model."
+            state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
+            y = jnp.concatenate([y, state_embed], axis=-2)
+            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)
+
+        y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])
+
+        for _ in range(self.num_layers):
+            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
+                x, y, train=train, mask_self=None, mask_cross=attn_mask
+            )
+
+        return x
+
+
+class FsqAttentionTokenizer(nn.Module):
+    embed_dim: int
+    data_dim: int
+    data_horizon: int
+    num_tokens: int
+    num_layers: int
+    target_codebook_size: int
+    causal: bool = False
+    mlp_ratio: float = 2.0
+
+    bound: float | None = None
+
+    use_state_conditioning: bool = False
+
+    @property
+    def vocab_size(self):
+        # The codebook below is constructed with codebook_type="custom", so the
+        # vocabulary size must be derived from the same bin layout.
+        return math.prod(FsqCodebook._get_bins_custom(self.target_codebook_size))
+
+    def setup(self):
+        self.proj = nn.Dense(self.embed_dim)
+        self.encoder = TokenizerEncoderDecoder(
+            num_tokens=self.num_tokens,
+            num_cross_tokens=self.data_horizon,
+            num_layers=self.num_layers,
+            causal=self.causal,
+            use_state_conditioning=self.use_state_conditioning,
+            mlp_ratio=self.mlp_ratio,
+        )
+        self.codebook = FsqCodebook(
+            input_dim=self.embed_dim,
+            target_codebook_size=self.target_codebook_size,
+            codebook_type="custom",
+        )
+        self.decoder = TokenizerEncoderDecoder(
+            num_tokens=self.data_horizon,
+            num_cross_tokens=self.num_tokens,
+            num_layers=self.num_layers,
+            causal=self.causal,
+            use_state_conditioning=self.use_state_conditioning,
+            mlp_ratio=self.mlp_ratio,
+        )
+
+        self.proj_mean = nn.Dense(self.data_dim)
+        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))
+
+    def tokenize(self, action, *, obs=None, train=False):
+        if self.bound is not None:
+            action = jnp.clip(action, -self.bound, self.bound)
+
+        x = self.proj(action)
+        x = self.encoder(x, train=train, state_conditioning=obs)
+
+        return self.codebook.encode(x)
+
+    def detokenize(self, tokens, *, obs=None):
+        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
+        mean = self.proj_mean(x)
+        return mean * self.out_scale
+
+    def loss(self, action, *, obs=None, train=True):
+        # Encode
+        x = self.proj(action)
+        z = self.encoder(x, train=train, state_conditioning=obs)
+
+        # Quantize
+        tokens, z = self.codebook(z)
+
+        # Decode
+        x = self.decoder(z, train=train, state_conditioning=obs)
+        mean = self.proj_mean(x) * self.out_scale
+
+        mse = jnp.mean(jnp.square(action - mean))
+        mae = jnp.mean(jnp.abs(action - mean))
+
+        return mse, {
+            "mse": mse,
+            "mae": mae,
+        }
+
+    def __call__(self, *args, **kwargs):
+        """
+        Dummy for .init
+        """
+        return self.loss(*args, **kwargs)
diff --git a/src/openpi/models/tokenizer.py b/src/openpi/models/tokenizer.py
index 29451c0..881dafe 100644
--- a/src/openpi/models/tokenizer.py
+++ b/src/openpi/models/tokenizer.py
@@ -1,4 +1,5 @@
 import logging
+import os
 
 import numpy as np
 import sentencepiece
@@ -125,3 +126,215 @@ class FASTTokenizer:
         if isinstance(tokens, list):
             tokens = np.array(tokens)
         return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
+
+
+class BinningTokenizer:
+    def __init__(self, max_len: int = 256, n_bins: int = 256):
+        self._max_len = max_len
+        self._n_bins = n_bins
+
+        # Download base PaliGemma tokenizer
+        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
+        with path.open("rb") as f:
+            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
+
+        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+    def tokenize(
+        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        cleaned_text = prompt.lower().strip().replace("_", " ")
+
+        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+        # Convention: prefix includes prompt and string-representation of state, followed by ';'
+        state_str = " ".join(map(str, discretized_state))
+        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
+        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+        if actions is not None:
+            raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
+        postfix_tokens = []
+
+        # Create output token sequence & masks
+        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+        tokens = prefix_tokens + postfix_tokens
+        token_mask = [True] * len(tokens)
+        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only
+
+        # Pad tokens to max length
+        tokens_len = len(tokens)
+        if tokens_len < self._max_len:
+            padding = [False] * (self._max_len - tokens_len)
+            tokens = tokens + padding
+            token_mask = token_mask + padding
+            ar_mask = ar_mask + padding
+            loss_mask = loss_mask + padding
+        else:
+            if len(tokens) > self._max_len:
+                logging.warning(
+                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
+                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
+                )
+            tokens = tokens[: self._max_len]
+            token_mask = token_mask[: self._max_len]
+            ar_mask = ar_mask[: self._max_len]
+            loss_mask = loss_mask[: self._max_len]
+
+        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
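+
+    # Example prefix (illustration only): for prompt "Pick up the block" and a
+    # normalized 2-dim state [0.0, -1.0], the discretized bins are [128, 0] and
+    # the encoded prefix string is:
+    #
+    #   "Task: pick up the block, State: 128 0;\n"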
+ ) + tokens = tokens[: self._max_len] + token_mask = token_mask[: self._max_len] + ar_mask = ar_mask[: self._max_len] + loss_mask = loss_mask[: self._max_len] + + return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) + + def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: + # Decode predicted output tokens + decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) + + # Extract actions from FAST model outputs + if "Action: " not in decoded_tokens: + return np.zeros((action_horizon, action_dim), dtype=np.float32) + + # Extract actions from decoded tokens + raw_action_tokens = np.array( + self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) + ) + action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) + if len(action_tokens) < action_horizon * action_dim: + return np.zeros([action_horizon, action_dim], dtype=np.float32) + action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim]) + return action_tokens / self._n_bins * 2 - 1 + + def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: + if isinstance(tokens, list): + tokens = np.array(tokens) + return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens + + +class FSQTokenizer: + def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None): + import jax + import orbax.checkpoint as ocp + + import openpi.models.fsq_tokenizer_v2 as fsq_tokenizer + + self._max_len = max_len + + assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided" + # Download tokenizer + path = download.maybe_download(fsq_tokenizer_path) + tok_path = os.path.join(path, os.listdir(path)[0]) + + # Split step from path + step = int(tok_path.split("/")[-1]) + base_path = tok_path.rsplit("/", 1)[0] + + mgr = ocp.CheckpointManager( + base_path, + item_handlers={ + "params": ocp.StandardCheckpointHandler(), + "opt_state": ocp.StandardCheckpointHandler(), + "config": ocp.JsonCheckpointHandler(), + }, + options=ocp.CheckpointManagerOptions(max_to_keep=1), + ) + + try: + restored = mgr.restore( + step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore()) + ) + config = restored["config"] + self._params = restored["params"] + self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config) + except Exception as e: + raise RuntimeError( + f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. 
Error: {e!s}" + ) from e + + # Compile tokenize and detokenize functions + self._tokenize_fn = jax.jit( + lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize) + ) + self._detokenize_fn = jax.jit( + lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize) + ) + + # Download base PaliGemma tokenizer + path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) + with path.open("rb") as f: + self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) + + self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens + + def tokenize( + self, prompt: str, state: np.ndarray, actions: np.ndarray | None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + cleaned_text = prompt.lower().strip().replace("_", " ") + + # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) + discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + # Convention: prefix includes prompt and string-representation of state, followed by ';' + state_str = " ".join(map(str, discretized_state)) + prefix = f"Task: {cleaned_text}, State: {state_str};\n" + prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) + + if actions is not None: + raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)") + postfix_tokens = [] + + # Create output token sequence & masks + # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) + tokens = prefix_tokens + postfix_tokens + token_mask = [True] * len(tokens) + ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) + loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only + + # Pad tokens to max length + tokens_len = len(tokens) + if tokens_len < self._max_len: + padding = [False] * (self._max_len - tokens_len) + tokens = tokens + padding + token_mask = token_mask + padding + ar_mask = ar_mask + padding + loss_mask = loss_mask + padding + else: + if len(tokens) > self._max_len: + logging.warning( + f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " + "Consider increasing the `max_token_len` in your model config if this happens frequently." 
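+
+    # Decoding sketch (illustration only): the model emits PaliGemma tokens,
+    # `_act_tokens_to_paligemma_tokens` maps them back to FSQ codebook indices,
+    # and `_detokenize_fn` (the jitted FsqAttentionTokenizer.detokenize)
+    # reconstructs the continuous action chunk, roughly:
+    #
+    #   >>> fsq_ids = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+    #   >>> actions = self._detokenize_fn(self._params, fsq_ids[None, ...])[0]
+    #   >>> actions[: action_horizon * action_dim].reshape([action_horizon, action_dim])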
+ ) + tokens = tokens[: self._max_len] + token_mask = token_mask[: self._max_len] + ar_mask = ar_mask[: self._max_len] + loss_mask = loss_mask[: self._max_len] + + return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) + + def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: + # Decode predicted output tokens + decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) + + # Extract actions from FAST model outputs + if "Action: " not in decoded_tokens: + return np.zeros((action_horizon, action_dim), dtype=np.float32) + + # Extract actions from decoded tokens + raw_action_tokens = np.array( + self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) + ) + action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) + try: + import jax + + # Move computation to CPU and compile on-demand + device = jax.devices("cpu")[0] + with jax.default_device(device): + detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0] + return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim]) + except Exception as e: + logging.warning(f"Error decoding FSQ: {e}") + return np.zeros((action_horizon, action_dim)) + + def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: + if isinstance(tokens, list): + tokens = np.array(tokens) + return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 164d59b..dd0fc04 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -102,6 +102,8 @@ class ModelTransformFactory(GroupFactory): # If provided, will determine the default prompt that be used by the model. 
+    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
+        if isinstance(tokens, list):
+            tokens = np.array(tokens)
+        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py
index 164d59b..dd0fc04 100644
--- a/src/openpi/training/config.py
+++ b/src/openpi/training/config.py
@@ -102,6 +102,8 @@ class ModelTransformFactory(GroupFactory):
     # If provided, will determine the default prompt that be used by the model.
     default_prompt: str | None = None
+    fast_model_tokenizer: Any | None = None
+    fast_model_tokenizer_kwargs: dict[str, Any] | None = None
 
     def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
         match model_config.model_type:
@@ -116,17 +118,21 @@ class ModelTransformFactory(GroupFactory):
                     ],
                 )
             case _model.ModelType.PI0_FAST:
+                tokenizer_cls = (
+                    _tokenizer.FASTTokenizer if self.fast_model_tokenizer is None else self.fast_model_tokenizer
+                )
+                tokenizer_kwargs = {} if self.fast_model_tokenizer_kwargs is None else self.fast_model_tokenizer_kwargs
                 return _transforms.Group(
                     inputs=[
                         _transforms.InjectDefaultPrompt(self.default_prompt),
                         _transforms.ResizeImages(224, 224),
                         _transforms.TokenizeFASTInputs(
-                            _tokenizer.FASTTokenizer(model_config.max_token_len),
+                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                         ),
                     ],
                     outputs=[
                         _transforms.ExtractFASTActions(
-                            _tokenizer.FASTTokenizer(model_config.max_token_len),
+                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                             action_horizon=model_config.action_horizon,
                             action_dim=model_config.action_dim,
                         )
@@ -470,6 +476,78 @@ _CONFIGS = [
             ),
         ),
     ),
+    TrainConfig(
+        name="pi0_fast_droid_jointpos",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[
+                    _transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
+                    droid_policy.DroidOutputs(),
+                ],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+        ),
+    ),
+    TrainConfig(
+        name="paligemma_binning_droid",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[droid_policy.DroidOutputs()],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+            model_transforms=ModelTransformFactory(
+                fast_model_tokenizer=_tokenizer.BinningTokenizer,
+            ),
+        ),
+    ),
+    TrainConfig(
+        name="paligemma_fast_specialist_droid",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[droid_policy.DroidOutputs()],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+            model_transforms=ModelTransformFactory(
+                fast_model_tokenizer=_tokenizer.FASTTokenizer,
+                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
+            ),
+        ),
+    ),
+    TrainConfig(
+        name="paligemma_vq_droid",
+        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
+        data=SimpleDataConfig(
+            assets=AssetsConfig(asset_id="droid"),
+            data_transforms=lambda model: _transforms.Group(
+                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
+                outputs=[droid_policy.DroidOutputs()],
+            ),
+            base_config=DataConfig(
+                prompt_from_task=True,
+            ),
+            model_transforms=ModelTransformFactory(
+                fast_model_tokenizer=_tokenizer.FSQTokenizer,
+                fast_model_tokenizer_kwargs={
+                    "fsq_tokenizer_path": "s3://openpi-assets-simeval/tokenizers/droid_fsq_tokenizer"
+                },
+            ),
+        ),
+    ),
     #
     # Fine-tuning Libero configs.
     #
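
Note: with this patch applied, the new DROID policies should be trainable by config name through the usual openpi entry point (assuming the standard scripts/train.py workflow), e.g.:

    uv run scripts/train.py paligemma_binning_droid --exp-name=my_droid_experiment

The paligemma_fast_specialist_droid and paligemma_vq_droid configs work the same way; the FSQ variant additionally restores the checkpoint referenced by fsq_tokenizer_path when the tokenizer is constructed.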