7 Commits

Author SHA1 Message Date
Karl Pertsch
b84cc75031 add binning jointpos 2025-04-25 05:28:23 +00:00
Karl Pertsch
c23bc86a0a load droid sim eval policies without credentials (#440)
small change to enable loading from the openpi sim eval bucket without credentials (for joint pos policies)
2025-04-17 15:39:53 -04:00
Arhan Jain
fe5d5580a4 load droid sim eval policies without credentials 2025-04-17 12:26:06 -07:00
Karl Pertsch
650b02e4ca add diffusion jointpos policy 2025-04-17 13:19:48 +00:00
Karl Pertsch
e43516e719 add diffusion droid policy 2025-04-14 20:15:23 +00:00
Karl Pertsch
20d63d47b7 additional policy 2025-04-14 19:18:09 +00:00
Karl Pertsch
1ce9ffe134 add DROID policies 2025-04-14 18:42:57 +00:00
4 changed files with 825 additions and 3 deletions

View File

@@ -0,0 +1,466 @@
import math
from typing import Literal
import chex
from einops import einops
from flax import linen as nn
from flax.linen.module import Module
from flax.linen.module import compact
from flax.struct import dataclass
from flax.typing import Array
import jax
import jax.numpy as jnp
class FsqCodebook(nn.Module):
input_dim: int
target_codebook_size: int
codebook_type: Literal["fsq", "lfq"]
_bins_per_dim: tuple[int] | None = None
@property
def bins_per_dim(self):
if self._bins_per_dim is not None:
return self._bins_per_dim
if self.codebook_type == "fsq":
return self._get_bins_fsq(self.target_codebook_size)
elif self.codebook_type == "lfq": # noqa: RET505
return self._get_bins_lfq(self.target_codebook_size)
elif self.codebook_type == "custom":
return self._get_bins_custom(self.target_codebook_size)
else:
raise ValueError(f"Codebook type {self.codebook_type} not supported.")
@property
def place_values(self):
place_values = [1]
for b in self.bins_per_dim[:-1]:
place_values.append(place_values[-1] * b)
return jnp.array(place_values)
@staticmethod
def _get_bins_fsq(target_codebook_size):
"""
Get bins per dimension based on codebook size, from the original FSQ paper.
"""
if target_codebook_size == 2**8:
return (8, 6, 5)
elif target_codebook_size == 2**10: # noqa: RET505
return (8, 5, 5, 5)
elif target_codebook_size == 2**12:
return (7, 5, 5, 5, 5)
elif target_codebook_size == 2**14:
return (8, 8, 8, 6, 5)
elif target_codebook_size == 2**16:
return (8, 8, 8, 5, 5, 5)
else:
raise ValueError(f"Codebook size {target_codebook_size} not supported.")
@staticmethod
def _get_bins_custom(target_codebook_size):
if target_codebook_size == 2**8:
return (16, 16)
elif target_codebook_size == 2**10: # noqa: RET505
return (32, 32)
elif target_codebook_size == 2**12:
return (64, 64)
elif target_codebook_size == 2**14:
return (128, 128)
elif target_codebook_size == 2**16:
return (256, 256)
return None
@staticmethod
def _get_bins_lfq(target_codebook_size):
"""
Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
"""
assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"
return (2,) * int(math.log2(target_codebook_size))
def setup(self):
self.proj_down = nn.Dense(len(self.bins_per_dim))
self.proj_up = nn.Dense(self.input_dim)
def __call__(self, inputs):
tokens, z = self.encode(inputs)
output = self.decode(tokens, z_grad=z)
return tokens, output
def encode(self, inputs):
bases = jnp.array(self.bins_per_dim)
x = self.proj_down(inputs)
z = jnp.tanh(x)
# Quantize
digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
tokens = self.undigitize(digits)
return tokens, z
def decode(self, tokens, z_grad: jax.Array | None = None):
bases = jnp.array(self.bins_per_dim)
digits = self.digitize(tokens)
z_q = digits / (bases - 1) * 2 - 1
if z_grad is not None:
chex.assert_equal_shape([z_q, z_grad])
z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
return self.proj_up(z_q)
def undigitize(self, digits):
return jnp.sum(digits * jnp.array(self.place_values), axis=-1)
def digitize(self, tokens):
return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)
@property
def vocab_size(self):
return math.prod(self.bins_per_dim)
class ResNetDownBlock(nn.Module):
stride: int = 1
n_filters: int = 64
dropout_rate: float = 0.0
group_size: int = 32
@nn.compact
def __call__(self, x, *, train=True):
skip = x
if self.stride > 1 or x.shape[-1] != self.n_filters:
skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = nn.relu(x)
x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)
return skip + x
class ResNetUpBlock(nn.Module):
stride: int = 1
n_filters: int = 64
dropout_rate: float = 0.0
group_size: int = 32
@nn.compact
def __call__(self, x, *, train=True):
skip = x
if self.stride > 1:
skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = nn.relu(x)
x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)
return skip + x
@dataclass
class LfqCodebookOutput:
tokens: jnp.ndarray
z: jnp.ndarray
z_q: jnp.ndarray
token_log_probs: jnp.ndarray
commit_loss: jnp.ndarray
class LookupFreeQuantization(nn.Module):
num_dims: int
latent_dim: int
def setup(self):
self.codebook = jnp.array([-1, 1])
# self.activation = lambda x: x
self.activation = nn.tanh
self.project_down = nn.Dense(self.num_dims)
self.project_up = nn.Dense(self.latent_dim)
def encode(self, z):
z = self.project_down(z)
token_squared_distances = jnp.square(z[..., None] - self.codebook)
token_bits = jnp.argmin(token_squared_distances, axis=-1)
return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)
def decode(self, tokens):
token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
return self.project_up(self.codebook[token_bits])
def loss(self, x):
z = self.project_down(x)
z = self.activation(z)
token_squared_distances = jnp.square(z[..., None] - self.codebook)
tokens = jnp.argmin(token_squared_distances, axis=-1)
token_bit_log_probs = -token_squared_distances # jax.nn.log_softmax(-token_squared_distances, axis=-1)
# Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
token_bit_expansions = jnp.bitwise_and(
jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
).astype(jnp.int32)
token_log_probs = (
token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
+ token_bit_log_probs[..., 1] @ token_bit_expansions
) # (batch_size, num_tokens, 2 ** num_dims)
token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
z_q = self.codebook[tokens]
commit_loss = jnp.square(z - z_q).mean()
z_q = jax.lax.stop_gradient(z_q - z) + z
z_q = self.project_up(z_q)
z = self.project_up(z)
tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
return LfqCodebookOutput(
tokens=tokens,
z=z,
z_q=z_q,
token_log_probs=jnp.zeros(()),
commit_loss=commit_loss,
)
def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
class GeGLU(Module):
"""Gated Linear Unit with GELU (GeGLU) activation function.
GeGLU is a Flax layer that combines a linear transformation with a GELU
activation function in a gating mechanism. It is often used in Transformer models
to provide non-linear capabilities while preserving a strong linear component.
Example usage::
>>> import flax.linen as nn
>>> class TransformerBlock(nn.Module):
... @nn.compact
... def __call__(self, x):
... x = nn.Dense(2)(x)
... x = nn.GeGLU()(x) # initialized
... return x
Attributes:
features: the number of output features (default: None).
"""
output_dim: int = -1
@compact
def __call__(self, inputs: Array) -> Array:
"""Applies the GeGLU activation to the inputs.
Args:
inputs: the nd-array to apply the GeGLU activation function to.
Returns:
The transformed input.
"""
if self.output_dim == -1:
output_dim = inputs.shape[-1]
else:
output_dim = self.output_dim
x = nn.Dense(output_dim * 2)(inputs)
x, gate = x[..., :output_dim], x[..., output_dim:]
return x * nn.gelu(gate)
class CrossAttentionLayer(nn.Module):
dropout_rate: float = 0.0
num_heads: int = None
causal: bool = False
mlp_ratio: float = 4.0
@nn.compact
def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
d_embed = x.shape[-1]
seq_len_q = x.shape[-2]
seq_len_k = y.shape[-2]
if self.causal:
# One block size will be 1
bs_q = max(seq_len_q // seq_len_k, 1)
bs_k = max(seq_len_k // seq_len_q, 1)
mask_self = nn.make_causal_mask(x[..., 0])
mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)
# Self-attention block
skip = x
x = nn.LayerNorm()(x)
x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads or d_embed // 64,
dropout_rate=self.dropout_rate,
deterministic=not train,
)(x, x, x, mask=mask_self)
x = skip + x
# Cross-attention block
skip = x
x = nn.LayerNorm()(x)
# bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads or d_embed // 64,
dropout_rate=self.dropout_rate,
deterministic=not train,
# attention_fn=partial(nn.dot_product_attention, bias=bias),
)(x, y, y, mask=mask_cross)
x = skip + x
# MLP block
skip = x
x = nn.LayerNorm()(x)
x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = GeGLU()(x)
x = nn.Dense(d_embed)(x)
return skip + x
def sinusoidal_pe_init(_, shape):
seq_len, d_embed = shape
position = jnp.arange(0, seq_len, 1)
div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
return jnp.concatenate(
[
jnp.sin(position[:, jnp.newaxis] * div_term),
jnp.cos(position[:, jnp.newaxis] * div_term),
],
axis=-1,
)
class TokenizerEncoderDecoder(nn.Module):
num_tokens: int
num_cross_tokens: int
num_layers: int
causal: bool
mlp_ratio: float = 4.0
use_state_conditioning: bool = False
@nn.compact
def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
if mask is not None:
# mask is (batch_dims..., num_cross_tokens)
chex.assert_equal_shape([y[..., 0], mask])
attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
else:
attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens))
if self.use_state_conditioning:
assert state_conditioning is not None, "State conditioning is required for this model."
state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
y = jnp.concatenate([y, state_embed], axis=-2)
attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)
y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])
for _ in range(self.num_layers):
x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
x, y, train=train, mask_self=None, mask_cross=attn_mask
)
return x
class FsqAttentionTokenizer(nn.Module):
embed_dim: int
data_dim: int
data_horizon: int
num_tokens: int
num_layers: int
target_codebook_size: int
causal: bool = False
mlp_ratio: float = 2.0
bound: float | None = None
use_state_conditioning: bool = False
@property
def vocab_size(self):
return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))
def setup(self):
self.proj = nn.Dense(self.embed_dim)
self.encoder = TokenizerEncoderDecoder(
num_tokens=self.num_tokens,
num_cross_tokens=self.data_horizon,
num_layers=self.num_layers,
causal=self.causal,
use_state_conditioning=self.use_state_conditioning,
mlp_ratio=self.mlp_ratio,
)
self.codebook = FsqCodebook(
input_dim=self.embed_dim,
target_codebook_size=self.target_codebook_size,
codebook_type="custom",
)
self.decoder = TokenizerEncoderDecoder(
num_tokens=self.data_horizon,
num_cross_tokens=self.num_tokens,
num_layers=self.num_layers,
causal=self.causal,
use_state_conditioning=self.use_state_conditioning,
mlp_ratio=self.mlp_ratio,
)
self.proj_mean = nn.Dense(self.data_dim)
self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))
def tokenize(self, action, *, obs=None, train=False):
if self.bound is not None:
action = jnp.clip(action, -self.bound, self.bound)
x = self.proj(action)
x = self.encoder(x, train=train, state_conditioning=obs)
return self.codebook.encode(x)
def detokenize(self, tokens, *, obs=None):
x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
mean = self.proj_mean(x)
return mean * self.out_scale
def loss(self, action, *, obs=None, train=True):
# Encode
x = self.proj(action)
z = self.encoder(x, train=train, state_conditioning=obs)
# Quantize
tokens, z = self.codebook(z)
# Decode
x = self.decoder(z, train=train, state_conditioning=obs)
mean = self.proj_mean(x) * self.out_scale
mse = jnp.mean(jnp.square(action - mean))
mae = jnp.mean(jnp.abs(action - mean))
return mse, {
"mse": mse,
"mae": mae,
}
def __call__(self, *args, **kwargs):
"""
Dummy for .init
"""
return self.loss(*args, **kwargs)

View File

@@ -1,4 +1,5 @@
import logging
import os
import numpy as np
import sentencepiece
@@ -125,3 +126,215 @@ class FASTTokenizer:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
class BinningTokenizer:
def __init__(self, max_len: int = 256, n_bins: int = 256):
self._max_len = max_len
self._n_bins = n_bins
# Download base PaliGemma tokenizer
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
postfix_tokens = []
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens = prefix_tokens + postfix_tokens
token_mask = [True] * len(tokens)
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
# Pad tokens to max length
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [False] * (self._max_len - tokens_len)
tokens = tokens + padding
token_mask = token_mask + padding
ar_mask = ar_mask + padding
loss_mask = loss_mask + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
token_mask = token_mask[: self._max_len]
ar_mask = ar_mask[: self._max_len]
loss_mask = loss_mask[: self._max_len]
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
if len(action_tokens) < action_horizon * action_dim:
return np.zeros([action_horizon, action_dim], dtype=np.float32)
action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
return action_tokens / self._n_bins * 2 - 1
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
class FSQTokenizer:
def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
import jax
import orbax.checkpoint as ocp
import openpi.models.fsq_tokenizer_v2 as fsq_tokenizer
self._max_len = max_len
assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
# Download tokenizer
path = download.maybe_download(fsq_tokenizer_path)
tok_path = os.path.join(path, os.listdir(path)[0])
# Split step from path
step = int(tok_path.split("/")[-1])
base_path = tok_path.rsplit("/", 1)[0]
mgr = ocp.CheckpointManager(
base_path,
item_handlers={
"params": ocp.StandardCheckpointHandler(),
"opt_state": ocp.StandardCheckpointHandler(),
"config": ocp.JsonCheckpointHandler(),
},
options=ocp.CheckpointManagerOptions(max_to_keep=1),
)
try:
restored = mgr.restore(
step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
)
config = restored["config"]
self._params = restored["params"]
self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
except Exception as e:
raise RuntimeError(
f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
) from e
# Compile tokenize and detokenize functions
self._tokenize_fn = jax.jit(
lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
)
self._detokenize_fn = jax.jit(
lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
)
# Download base PaliGemma tokenizer
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
postfix_tokens = []
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens = prefix_tokens + postfix_tokens
token_mask = [True] * len(tokens)
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
# Pad tokens to max length
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [False] * (self._max_len - tokens_len)
tokens = tokens + padding
token_mask = token_mask + padding
ar_mask = ar_mask + padding
loss_mask = loss_mask + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
token_mask = token_mask[: self._max_len]
ar_mask = ar_mask[: self._max_len]
loss_mask = loss_mask[: self._max_len]
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
try:
import jax
# Move computation to CPU and compile on-demand
device = jax.devices("cpu")[0]
with jax.default_device(device):
detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
except Exception as e:
logging.warning(f"Error decoding FSQ: {e}")
return np.zeros((action_horizon, action_dim))
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens

View File

@@ -93,7 +93,7 @@ def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathl
logger.info(f"Downloading {url} to {local_path}")
scratch_path = local_path.with_suffix(".partial")
if _is_openpi_url(url):
if _is_openpi_url(url) or _is_openpi_simeval_url(url):
# Download without credentials.
_download_boto3(
url,
@@ -299,6 +299,9 @@ def _is_openpi_url(url: str) -> bool:
"""Check if the url is an OpenPI S3 bucket url."""
return url.startswith("s3://openpi-assets/")
def _is_openpi_simeval_url(url: str) -> bool:
"""Check if the url is an OpenPI simeval S3 bucket url."""
return url.startswith("s3://openpi-assets-simeval/")
def _get_mtime(year: int, month: int, day: int) -> float:
"""Get the mtime of a given date at midnight UTC."""

View File

@@ -102,6 +102,8 @@ class ModelTransformFactory(GroupFactory):
# If provided, will determine the default prompt that be used by the model.
default_prompt: str | None = None
fast_model_tokenizer: Any | None = None
fast_model_tokenizer_kwargs: dict[str, Any] | None = None
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
match model_config.model_type:
@@ -116,17 +118,21 @@ class ModelTransformFactory(GroupFactory):
],
)
case _model.ModelType.PI0_FAST:
tokenizer_cls = (
_tokenizer.FASTTokenizer if self.fast_model_tokenizer is None else self.fast_model_tokenizer
)
tokenizer_kwargs = {} if self.fast_model_tokenizer_kwargs is None else self.fast_model_tokenizer_kwargs
return _transforms.Group(
inputs=[
_transforms.InjectDefaultPrompt(self.default_prompt),
_transforms.ResizeImages(224, 224),
_transforms.TokenizeFASTInputs(
_tokenizer.FASTTokenizer(model_config.max_token_len),
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
),
],
outputs=[
_transforms.ExtractFASTActions(
_tokenizer.FASTTokenizer(model_config.max_token_len),
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
action_horizon=model_config.action_horizon,
action_dim=model_config.action_dim,
)
@@ -470,6 +476,140 @@ _CONFIGS = [
),
),
),
TrainConfig(
name="pi0_droid_jointpos",
model=pi0.Pi0Config(action_horizon=10),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
outputs=[_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)), droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
name="pi0_fast_droid_jointpos",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[
_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
droid_policy.DroidOutputs(),
],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
name="paligemma_binning_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.BinningTokenizer,
),
),
),
TrainConfig(
name="paligemma_binning_droid_jointpos",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[
_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
droid_policy.DroidOutputs(),
],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.BinningTokenizer,
),
),
),
TrainConfig(
name="paligemma_fast_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
name="paligemma_fast_specialist_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.FASTTokenizer,
fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
),
),
),
TrainConfig(
name="paligemma_vq_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.FSQTokenizer,
fast_model_tokenizer_kwargs={
"fsq_tokenizer_path": "s3://openpi-assets-simeval/tokenizers/droid_fsq_tokenizer"
},
),
),
),
TrainConfig(
name="paligemma_diffusion_droid",
model=pi0.Pi0Config(action_horizon=10, action_dim=8),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
#
# Fine-tuning Libero configs.
#