multi-node openpi commit

This commit is contained in:
Leon998
2026-03-17 23:05:23 +08:00
parent 28833f0c0f
commit 7411e0e004
156 changed files with 33951 additions and 1 deletions

View File

@@ -0,0 +1,17 @@
import os
import pynvml
import pytest
def set_jax_cpu_backend_if_no_gpu() -> None:
    """Force JAX onto the CPU backend when no NVIDIA GPU is usable.

    Probes the NVIDIA driver via NVML. If the driver is missing/broken, or it
    is present but reports zero devices, sets JAX_PLATFORMS=cpu so the test
    session does not fail trying to initialize a GPU backend.
    """
    try:
        pynvml.nvmlInit()
        try:
            # nvmlInit can succeed on a machine with the driver installed but
            # no GPU attached; count the devices explicitly.
            device_count = pynvml.nvmlDeviceGetCount()
        finally:
            pynvml.nvmlShutdown()
    except pynvml.NVMLError:
        # No driver / NVML unavailable: treat as "no GPU".
        device_count = 0
    if device_count == 0:
        os.environ["JAX_PLATFORMS"] = "cpu"
def pytest_configure(config: pytest.Config) -> None:
    """Pytest hook: runs once at session start, before test collection."""
    set_jax_cpu_backend_if_no_gpu()

View File

@@ -0,0 +1,459 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma adaptation for Pi, taken from big_vision.
We follow this einsum axis naming convention:
B: batch
T: query length
S: k/v length
N: num query heads
K: num k/v heads
G: num query heads per k/v head
H: head dim
D: d_model ("features")
"""
from collections.abc import Sequence
import dataclasses
from typing import Literal, TypeAlias
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import openpi.models.lora as lora
import openpi.shared.array_typing as at
import openpi.training.sharding as sharding
PALIGEMMA_VOCAB_SIZE = 257_152
@dataclasses.dataclass
class Config:
    """Hyperparameters for one Gemma transformer expert."""

    width: int  # d_model ("D" in the einsum convention)
    depth: int  # number of transformer blocks
    mlp_dim: int  # hidden dim of the feed-forward layer
    num_heads: int  # query heads ("N")
    num_kv_heads: int  # k/v heads ("K")
    head_dim: int  # per-head dim ("H")
    # Optional LoRA configs keyed by "attn" / "ffn"; empty dict disables LoRA.
    lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)


Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]
def get_config(variant: Variant) -> Config:
    """Returns config for specified gemma variant."""
    # Architectures shared by the base and LoRA variants of each size.
    gemma_300m_kw = dict(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
    gemma_2b_kw = dict(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)

    if variant == "dummy":
        return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
    if variant == "gemma_300m":
        # 311M params
        return Config(**gemma_300m_kw)
    if variant == "gemma_2b":
        return Config(**gemma_2b_kw)
    if variant == "gemma_2b_lora":
        return Config(
            **gemma_2b_kw,
            lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
        )
    if variant == "gemma_300m_lora":
        # 311M params
        return Config(
            **gemma_300m_kw,
            lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
        )
    raise ValueError(f"Unknown variant: {variant}")
@at.typecheck
class RMSNorm(nn.Module):
    """RMS normalization with an optional adaptive (conditioned) scale/shift/gate.

    Returns a `(normed, gate)` pair; `gate` is None on the unconditioned path.
    """

    @nn.compact
    def __call__(self, x, cond):
        dtype = x.dtype  # original dtype, could be half-precision
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32
        if cond is None:
            # regular RMSNorm with a learned (1 + scale) gain
            scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
            normed_inputs = normed_inputs * (
                1 + scale
            )  # scale by learned parameter in float32 (matches Flax implementation)
            return normed_inputs.astype(dtype), None  # return in original dtype
        # adaptive RMSNorm: predict per-feature scale/shift/gate from `cond`;
        # zero-init kernel means the transform is the identity at initialization.
        modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
        # NOTE(review): assumes cond is (batch, features) and is broadcast over
        # the sequence axis via [:, None, :] — confirm with callers.
        scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
        normed_inputs = normed_inputs * (1 + scale) + shift  # scale and shift in float32
        return normed_inputs.astype(dtype), gate
@at.typecheck
class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        # Token ids -> embeddings, scaled by sqrt(d) as in the Gemma reference.
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        # Tied output head: project back to vocab logits via the transposed table.
        return jnp.dot(x, self.input_embedding_table.T)
@at.typecheck
class Attention(nn.Module):
    """Attention module."""

    configs: Sequence[Config]

    @nn.compact
    def __call__(self, xs, positions, attn_mask, kv_cache):
        """Joint self-attention over the concatenation of all experts' tokens.

        Args:
            xs: one array per expert (None to skip that expert).
            positions: absolute token positions for RoPE.
            attn_mask: [B, 1, T, S] boolean mask.
            kv_cache: optional (k, v) pair prepended to the fresh k/v.

        Returns:
            (per-expert outputs with None where the input was None, (k, v)).
        """
        # all experts must share the same head dim, num heads, and num kv heads for self-attention to work
        assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
        assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
        assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)

        dtype = next(x.dtype for x in xs if x is not None)  # original dtype, could be half-precision

        qkvs = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is None:
                continue
            if config.num_kv_heads == config.num_heads:
                # Multi-head attention: fused q/k/v projection.
                qkv_einsum = lora.Einsum(
                    shape=(3, config.num_heads, config.width, config.head_dim),
                    name=_name("qkv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
            else:
                # Grouped/multi-query attention: separate q and k/v projections.
                q_einsum = lora.Einsum(
                    shape=(config.num_heads, config.width, config.head_dim),
                    name=_name("q_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                    lora_config=config.lora_configs.get("attn"),
                )
                q = q_einsum("BTD,NDH->BTNH", x)
                kv_einsum = lora.Einsum(
                    shape=(2, config.num_kv_heads, config.width, config.head_dim),
                    name=_name("kv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                k, v = kv_einsum("BSD,2KDH->2BSKH", x)
                qkvs.append((q, k, v))

        # Concatenate every active expert's q/k/v along the sequence axis.
        q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))

        q = _apply_rope(q, positions=positions)
        q *= self.configs[0].head_dim ** -0.5
        k = _apply_rope(k, positions=positions)

        # should still be half-precision here (if input was half-precision)
        assert q.dtype == k.dtype == v.dtype == dtype

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            k = jnp.concatenate([cache_k, k], axis=1)
            v = jnp.concatenate([cache_v, v], axis=1)

        # Group the query heads by their shared k/v head (G = N // K).
        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        # big_neg = jnp.finfo(logits.dtype).min
        big_neg = -2.3819763e38  # See gemma/modules.py
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        # Softmax in float32, then cast back to the working dtype.
        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")

        # Split the attended sequence back into per-expert segments and project out.
        out = []
        start = 0
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                end = start + x.shape[1]
                out_einsum = lora.Einsum(
                    shape=(config.num_heads, config.head_dim, config.width),
                    name=_name("attn_vec_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
                    lora_config=config.lora_configs.get("attn"),
                )
                out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
                start = end
            else:
                out.append(None)

        return out, (k, v)
@at.typecheck
class FeedForward(nn.Module):
    """Gated (GeGLU-style) feed-forward layer: gelu(x W0) * (x W1) projected back."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        input_dtype = x.dtype  # may be half-precision; params are cast to match
        gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        ).astype(input_dtype)
        # Two parallel up-projections: one passed through GELU as the gate.
        hidden = nn.gelu(jnp.dot(x, gating[0])) * jnp.dot(x, gating[1])
        out_proj = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        ).astype(input_dtype)
        result = jnp.dot(hidden, out_proj)
        assert result.dtype == input_dtype
        return result
@at.typecheck
class Block(nn.Module):
    """Transformer block."""

    configs: tuple[Config, ...]
    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()

    @nn.compact
    def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True):  # noqa: FBT002
        """Pre-norm attention + FFN with (optionally gated) residuals, per expert."""
        xs = sharding.activation_sharding_constraint(xs)
        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x

        attn = Attention(configs=self.configs, name="attn")

        # Pre-attention RMSNorm (adaptive when adarms_cond[i] is not None).
        pre_attn = []
        gates = []
        for i, x in enumerate(xs):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i])  # noqa: PLW2901
            pre_attn.append(x)
            # `gate` is only evaluated when x is not None, so no NameError for skipped experts.
            gates.append(gate if x is not None else None)

        pre_attn = sharding.activation_sharding_constraint(pre_attn)
        post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
        post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
        post_attn = sharding.activation_sharding_constraint(post_attn)
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        # Pre-FFN RMSNorm followed by the (possibly LoRA-augmented) MLP.
        out = []
        gates = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i])  # noqa: PLW2901
                x = lora.FeedForward(  # noqa: PLW2901
                    features=config.width,
                    hidden_dim=config.mlp_dim,
                    name=_name("mlp", i),
                    lora_config=config.lora_configs.get("ffn"),
                )(x)
            out.append(x)
            gates.append(gate if x is not None else None)

        out = sharding.activation_sharding_constraint(out)
        out = jax.tree.map(lambda x: drop(x, deterministic), out)
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        return xs, kv_cache
# Per-layer stacked key/value cache: a (keys, values) pair, each with a leading layer axis.
KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
@at.typecheck
class Module(nn.Module):
    """Transformer model, supporting a mixture of different weights for different tokens."""

    configs: Sequence[Config]  # list of configs, one for each expert
    embed_dtype: str
    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.
    adarms: bool = False

    def setup(self):
        # all experts must have the same depth
        assert all(config.depth == self.configs[0].depth for config in self.configs)

        self.embedder = Embedder(
            vocab_size=PALIGEMMA_VOCAB_SIZE,
            embed_dim=self.configs[0].width,  # embedder for first expert only
            name="embedder",
        )

        block_cls = nn.remat(
            Block,
            prevent_cse=False,
            # NOTE(review): index 5 is `deterministic` when `self` is excluded
            # from the argument count (xs=0 ... adarms_cond=4, deterministic=5);
            # confirm against flax.linen.remat semantics.
            static_argnums=(5,),
            policy=jax.checkpoint_policies.nothing_saveable,
        )
        # Scan one Block over depth: kv_cache is scanned along axis 0, the rest
        # of the inputs are broadcast to every layer.
        self.layers = nn.scan(
            block_cls,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=(
                0,
                nn.broadcast,
                nn.broadcast,
                nn.broadcast,
                nn.broadcast,
            ),  # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
            length=self.configs[0].depth,
        )(
            configs=self.configs,
            dropout=self.dropout,
            dropout_bdims=self.dropout_bdims,
        )
        self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]

    @at.typecheck
    def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
        """Embed token ids with the first expert's embedder, cast to embed_dtype."""
        return self.embedder.encode(tokens).astype(self.embed_dtype)

    @at.typecheck
    def __call__(
        self,
        # list of token arrays, one for each expert, or None if that expert should not be run
        embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
        positions: at.Int[at.Array, "b t"],
        mask: at.Bool[at.Array, "b t s"],
        adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
        *,
        kv_cache: KVCache | None = None,
        deterministic: bool = True,
    ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
        """Run all layers and the per-expert final norms; returns (outputs, kv_cache)."""
        embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
        mask = jnp.asarray(mask)[:, None, :, :]  # add the head axis: [B, 1, T, S]
        if adarms_cond is None:
            adarms_cond = [None] * len(self.configs)
        embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)
        assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
        # Final norm per expert; [0] drops the (unused here) adaptive gate.
        return [
            f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
        ], kv_cache

    def init(self, use_adarms: Sequence[bool]):
        """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
        self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
        self(
            [jnp.zeros((1, 1, c.width)) for c in self.configs],
            jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
            jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
            adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
        )
def _apply_rope(x, *, positions, max_wavelength=10_000):
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
timescale = max_wavelength**freq_exponents
radians = positions[..., None] / timescale[None, None, :]
radians = radians[..., None, :]
assert radians.dtype == jnp.float32
# radians.shape = [...,L,1,d=D/2]
sin, cos = jnp.sin(radians), jnp.cos(radians)
x1, x2 = jnp.split(x, 2, axis=-1)
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
assert res.dtype == jnp.float32
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
# here.
return res.astype(x.dtype)
def _name(name, i):
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
# and the action expert.
if i == 0:
return name
return f"{name}_{i}"
def _gated_residual(x, y, gate):
assert (x is None) == (y is None)
if x is None:
return None
if gate is None:
return x + y
return x + y * gate

View File

@@ -0,0 +1,437 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
Used for FAST autoregressive policies.
"""
import dataclasses
from typing import Literal, TypeAlias
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import openpi.models.lora as lora
import openpi.shared.array_typing as at
Variant = Literal["gemma_2b", "gemma_2b_lora"]
def get_config(variant):
    """Returns config for specified gemma variant."""
    # Base 2B architecture shared by the plain and LoRA variants.
    base_2b = {
        "variant": variant,
        "width": 2048,
        "depth": 18,
        "mlp_dim": 16_384,
        "num_heads": 8,
        "num_kv_heads": 1,
        "head_dim": 256,
        "norm_eps": 1e-6,
        "vocab_size": 257_152,
        "scan": True,
        "remat_policy": "nothing_saveable",
    }
    if variant == "gemma_2b":
        return ml_collections.ConfigDict(base_2b)
    if variant == "gemma_2b_lora":
        lora_cfg = {
            "lora_configs": {
                "attn": lora.LoRAConfig(rank=16, alpha=16.0),
                "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
            }
        }
        return ml_collections.ConfigDict(base_2b | lora_cfg)
    raise ValueError(f"Unknown variant: {variant}")
@at.typecheck
class Einsum(nn.Module):
    """Thin wrapper around jnp.einsum with a single learned weight `w`."""

    # Shape of the weight tensor.
    shape: tuple[int, ...]

    @nn.compact
    def __call__(self, eqn, x):
        dtype = x.dtype  # original dtype, could be half-precision
        # Zero init: presumably weights are restored from a checkpoint — TODO confirm.
        w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
        return jnp.einsum(eqn, x, w)
@at.typecheck
class RMSNorm(nn.Module):
    """Gemma-style RMSNorm with a learned (1 + scale) gain."""

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # original dtype, could be half-precision
        scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32
        normed_inputs = normed_inputs * (
            1 + scale
        )  # scale by learned parameter in float32 (matches Flax implementation)
        return normed_inputs.astype(dtype)  # return in original dtype
@at.typecheck
class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        # Zero-init: presumably the table is loaded from a checkpoint — TODO confirm.
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.zeros_init(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        # Token ids -> embeddings, scaled by sqrt(d) per the Gemma reference.
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        # Tied output head: project back to vocab logits.
        return jnp.dot(x, self.input_embedding_table.T)
@at.typecheck
class Attention(nn.Module):
    """Attention module."""

    num_heads: int
    num_kv_heads: int
    features: int
    head_dim: int
    cache_dtype: str | None = None
    lora_config: lora.LoRAConfig | None = None

    def setup(self):
        if self.num_kv_heads == self.num_heads:
            # Multi-head attention: fused q/k/v projection.
            self.qkv_einsum = lora.Einsum(
                shape=(3, self.num_heads, self.features, self.head_dim),
                name="qkv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        else:
            # Grouped/multi-query attention: separate q and k/v projections.
            self.q_einsum = lora.Einsum(
                shape=(self.num_heads, self.features, self.head_dim),
                name="q_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                lora_config=self.lora_config,
            )
            self.kv_einsum = lora.Einsum(
                shape=(2, self.num_kv_heads, self.features, self.head_dim),
                name="kv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        self.attn_vec_einsum = lora.Einsum(
            shape=(self.num_heads, self.head_dim, self.features),
            name="attn_vec_einsum",
            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            lora_config=self.lora_config,
        )

    def _init_cache(self, k, v, cache_size):
        """Initialize KV cache"""
        # Right-pad prefill k/v out to cache_size and record the fill index per batch element.
        prefill_len = k.shape[1]
        pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
        cache_dtype = self.cache_dtype or k.dtype
        k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
        v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
        idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
        return idx, k_cache, v_cache

    def _update_cache(self, k, v, idx, k_cache, v_cache):
        """Update KV cache with new values"""
        assert k.shape[1] == 1, "Only support kv-cache updates of length 1"
        # NOTE(review): idx[0] is used for the whole batch — this assumes every
        # batch element decodes at the same position; confirm with callers.
        indices = (0, idx[0], 0, 0)
        cache_dtype = self.cache_dtype or k.dtype
        k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
        v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
        idx_new = idx + 1
        return idx_new, k_new, v_new

    @nn.compact
    def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True):  # noqa: FBT002
        """Self-attention over a padded KV cache (prefill on first call, then length-1 updates).

        `decode` and `deterministic` are unused in this body; they are kept in
        the signature so they can be marked static for remat/scan.
        """
        dtype = x.dtype  # original dtype, could be half-precision
        if self.num_kv_heads == self.num_heads:
            q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
        else:
            q = self.q_einsum("BTD,NDH->BTNH", x)
            k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)

        q = _apply_rope(q, positions=positions)  # promotes to float32
        q *= self.head_dim**-0.5
        k = _apply_rope(k, positions=positions)  # promotes to float32

        if kv_cache is None:
            # Prefill: allocate a cache sized to the mask's key dimension.
            idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
        else:
            idx, k_cache, v_cache = kv_cache
            idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)
        # Attend over the full (padded) cache; the mask hides unwritten slots.
        k, v = k_cache, v_cache
        kv_cache = (idx, k_cache, v_cache)

        # Group query heads by their shared k/v head (G = N // K).
        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        # big_neg = jnp.finfo(logits.dtype).min
        big_neg = -2.3819763e38  # See gemma/modules.py
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        # Softmax in float32, then cast back to the working dtype.
        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
        return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
@at.typecheck
class Block(nn.Module):
    """Transformer block."""

    num_heads: int
    num_kv_heads: int
    embed_dim: int
    head_dim: int
    hidden_dim: int
    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()
    cache_dtype: str | None = None
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    def setup(self):
        self.pre_attention_norm = RMSNorm()
        self.attn = Attention(
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            features=self.embed_dim,
            head_dim=self.head_dim,
            cache_dtype=self.cache_dtype,
            lora_config=self.lora_configs.get("attn"),
        )
        self.pre_ffw_norm = RMSNorm()
        self.mlp = lora.FeedForward(
            features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
        )
        if self.dropout:
            self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
        else:
            self.drop = lambda x, _: x  # no-op when dropout is disabled

    def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True):  # noqa: FBT002
        """Pre-norm attention + FFN, each with dropout and a residual connection."""
        x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
        inputs_normalized = self.pre_attention_norm(x)
        attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
        attn_output = self.drop(attn_output, deterministic)
        attn_output += x  # residual around attention
        residual = attn_output
        attn_output = self.pre_ffw_norm(attn_output)
        outputs = self.mlp(attn_output)
        outputs = self.drop(outputs, deterministic)
        outputs = residual + outputs  # residual around the FFN
        return outputs, kv_cache
# Decoding KV cache: (per-batch fill index, key cache, value cache).
KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]
@at.typecheck
class Module(nn.Module):
    """gemma model."""

    variant: str
    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    # NOTE(review): norm_eps appears unused below (RMSNorm hardcodes 1e-06) — confirm intended.
    norm_eps: float
    vocab_size: int
    embed_dtype: str
    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.
    cache_dtype: str | None = None
    scan: bool = False
    remat_policy: str = "none"
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    @nn.compact
    def __call__(
        self,
        tokens=None,
        embedded_prefix=None,
        embed_only=False,  # noqa: FBT002
        pre_logits=None,
        positions=None,
        mask=None,
        decode=False,  # noqa: FBT002
        kv_cache=None,
        deterministic=True,  # noqa: FBT002
        return_prelogits=False,  # noqa: FBT002
    ):
        """Embed only, or complete forward pass.

        Args:
          tokens: Embedded, then and appended to `embedded_prefix`. Can be None.
          embedded_prefix: Optional prefix that is already embedded.
          embed_only: Whether to compute embeddings only.
          pre_logits: If present computes logits from pre_logits and returns.
          positions: Optional `[B, T]` allows to specify the absolute position of
            the tokens.
          mask: Optional attention mask `[B, T, S]`.
          decode: Whether to use kv-cache. Caller must pass masks and positions.
          deterministic: Forwarded to all dropout layers.
          return_prelogits: Whether to return the pre-logits.

        Returns:
          If `embed_only=False`, then `(logits, out)` will be returned.
          If `embed_only=True`, then the embeddings will be returned.
          If `return_prelogits=True`, then the pre-logits will be returned.
        """
        out = {}

        embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")

        if pre_logits is not None:
            # Shortcut path: only project pre-logits to logits via the tied embedding.
            x = out["pre_logits"] = pre_logits
            logits = out["logits"] = embedder.decode(x)
            return logits, out

        # Assemble the input sequence: already-embedded prefix, then embedded tokens.
        x = []
        if embedded_prefix is not None:
            x.append(embedded_prefix)
        if tokens is not None:
            x.append(embedder.encode(tokens))
        x = jnp.concatenate(x, axis=-2)
        x = x.astype(self.embed_dtype)
        batch_size, seq_len, width = x.shape

        if embed_only:
            return x

        if decode:
            assert positions is not None and mask is not None, (  # noqa: PT018
                "Must explicitly pass positions and mask for decoding."
            )

        if positions is None:
            positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
        assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)

        if mask is None:
            mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
        if mask.ndim == 3:
            mask = mask[:, None, :, :]  # add the head axis
        # The mask's key dimension also fixes the KV-cache size (see Attention._init_cache).
        cache_size = max(seq_len, mask.shape[-1])
        assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape

        if self.remat_policy == "none":
            block_cls = Block
        else:
            block_cls = nn.remat(
                Block,
                prevent_cse=not self.scan,
                static_argnums=(5, 6),  # 0=self, 5=decode, 6=deterministic
                policy=getattr(jax.checkpoint_policies, self.remat_policy),
            )

        block_kw = {
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "num_kv_heads": self.num_kv_heads,
            "embed_dim": width,
            "hidden_dim": self.mlp_dim,
            "dropout": self.dropout,
            "dropout_bdims": self.dropout_bdims,
            "cache_dtype": self.cache_dtype,
            "lora_configs": self.lora_configs,
        }
        # Manual scope push so the scanned parameters live under "layers",
        # matching the checkpoint layout.
        layers = self.scope.push("layers")
        blocks = [
            nn.scan(
                block_cls,
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),  # 0=kv_cache, 1=positions, 2=mask
                length=self.depth,
            )(parent=layers, **block_kw)
        ]
        for block in blocks:
            x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)

        assert x.dtype == jnp.dtype(self.embed_dtype)  # Sanity check.
        out["encoded"] = x

        x = RMSNorm(name="final_norm")(x)
        out["pre_logits"] = x

        if return_prelogits:
            return x, kv_cache, out

        x = embedder.decode(x)
        out["logits"] = x
        return x, kv_cache, out

    def init(self):
        """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
        self(jnp.zeros((1, 1), dtype=jnp.int32))
def _apply_rope(x, *, positions, max_wavelength=10_000):
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
timescale = max_wavelength**freq_exponents
radians = positions[..., None] / timescale[None, None, :]
radians = radians[..., None, :]
assert radians.dtype == jnp.float32
# radians.shape = [...,L,1,d=D/2]
sin, cos = jnp.sin(radians), jnp.cos(radians)
x1, x2 = jnp.split(x, 2, axis=-1)
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
assert res.dtype == jnp.float32
return res

View File

@@ -0,0 +1,148 @@
import math
import re
import flax.linen as nn
import flax.struct as struct
import jax.numpy as jnp
import openpi.shared.array_typing as at
@struct.dataclass
class LoRAConfig:
    """Configuration for LoRA."""

    # LoRA rank.
    rank: int
    # LoRA scaling factor.
    alpha: float = 1.0
    # Initialization function for LoRA parameters.
    init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
    # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
    rslora: bool = False
    # Axes in the weight to apply LoRA to. Should typically be the last two axes.
    axes: tuple[int, int] = (-2, -1)
    # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
    label: str = "L"

    @property
    def scaling_value(self) -> float:
        """Multiplier for the low-rank update: alpha/rank, or alpha/sqrt(rank) for rsLoRA."""
        return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
class Einsum(nn.Module):
    """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""

    # Shape of the weight.
    shape: tuple[int, ...]
    # Initialization function for the weight.
    init_fn: nn.initializers.Initializer = nn.initializers.zeros
    # If not None, apply LoRA to the weight.
    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w = self.param("w", self.init_fn, self.shape)
        if config := self.lora_config:
            # Setup LoRA parameters: w is augmented by the low-rank product
            # w_a @ w_b along config.axes.
            shape_a, shape_b = list(self.shape), list(self.shape)
            shape_a[config.axes[1]] = config.rank
            shape_b[config.axes[0]] = config.rank
            self.w_a = self.param("lora_a", config.init_fn, shape_a)
            self.w_b = self.param("lora_b", config.init_fn, shape_b)

    @nn.compact
    def __call__(self, eqn: str, x):
        """Compute einsum(eqn, x, w), adding the scaled LoRA correction if configured."""
        dtype = x.dtype  # original dtype, could be half-precision
        result = jnp.einsum(eqn, x, self.w.astype(dtype))
        if config := self.lora_config:
            eqn_a, eqn_b = self._make_lora_eqns(eqn)
            lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
            lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
            result = result + lora * config.scaling_value
        return result

    def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
        """Derive the two einsum equations for the low-rank factors from `eqn`.

        The rank axis is labeled with `lora_config.label`; that label must not
        already appear in the original equation.
        """
        assert self.lora_config is not None
        label = self.lora_config.label
        # Bug fix: check for the configured label rather than the hard-coded "L",
        # which falsely rejected eqns containing "L" under a custom label and
        # missed real collisions with non-default labels.
        if label in eqn:
            raise ValueError(f"{label} already in eqn: {eqn}")
        if not (m := re.match("(.*),(.*)->(.*)", eqn)):
            raise ValueError(f"Unsupported einsum eqn: {eqn}")
        lhs, rhs, out = m.groups()
        # Single-character axis labels assumed (as in all call sites).
        a_label, b_label = (rhs[x] for x in self.lora_config.axes)
        a_rhs = rhs.replace(b_label, label)
        a_out = out.replace(b_label, label)
        eqn_a = f"{lhs},{a_rhs}->{a_out}"
        b_rhs = rhs.replace(a_label, label)
        eqn_b = f"{a_out},{b_rhs}->{out}"
        return eqn_a, eqn_b
class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int
    # If not None, apply LoRA to the weight.
    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        )
        self.w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        )
        self.w_gating_lora = None
        self.w_linear_lora = None
        if self.lora_config:
            # Setup LoRA parameters.
            # TODO: follow up with a simplified init_fn api.
            self.w_gating_lora = (
                self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
                self.param(
                    "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
                ),
            )
            self.w_linear_lora = (
                self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
                self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
            )

    @nn.compact
    def __call__(self, x):
        """Gated FFN: gelu(x @ W0) * (x @ W1) @ W_out, each dot optionally LoRA-augmented."""
        dtype = x.dtype  # original dtype, could be half-precision
        ff_gate = self._dot(
            x,
            self.w_gating[0],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
        )
        gate_value = nn.gelu(ff_gate)
        ff1 = self._dot(
            x,
            self.w_gating[1],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
        )
        activations = gate_value * ff1
        outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
        assert outputs.dtype == dtype
        return outputs

    def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
        # x @ w, plus x @ A @ B when LoRA weights are given.
        # NOTE(review): unlike Einsum, no `scaling_value` is applied here — confirm intended.
        base = jnp.dot(x, w.astype(x.dtype))
        if lora_weights is None:
            return base
        return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))

View File

@@ -0,0 +1,94 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
import openpi.models.lora as lora
def test_lora_einsum_params_shape():
    """LoRA adapter parameter shapes for default and user-provided axes."""
    weight_shape = (3, 8, 32, 4)  # (3KDH)
    base = lora.Einsum(weight_shape)
    with_default_axes = lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2))
    with_custom_axes = lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))
    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"
    # Without a LoRA config, no adapter parameters may be created.
    base_params = base.init(rng, eqn, inputs)
    assert "lora_a" not in base_params["params"]
    assert "lora_b" not in base_params["params"]
    # Default axes: the last two weight axes are factorized through the rank.
    default_params = with_default_axes.init(rng, eqn, inputs)
    assert default_params["params"]["lora_a"].shape == (3, 8, 32, 2)
    assert default_params["params"]["lora_b"].shape == (3, 8, 2, 4)
    # User-provided axes (1, 2) are factorized instead.
    custom_params = with_custom_axes.init(rng, eqn, inputs)
    assert custom_params["params"]["lora_a"].shape == (3, 8, 2, 4)
    assert custom_params["params"]["lora_b"].shape == (3, 2, 32, 4)
def test_lora_einsum_same_output():
    """Zero-initialized LoRA adapters must leave the einsum output unchanged."""
    weight_shape = (3, 8, 32, 4)  # (3KDH)
    base = lora.Einsum(weight_shape)
    adapted = lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))
    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"
    base_out = base.apply(base.init(rng, eqn, inputs), eqn, inputs)
    adapted_out = adapted.apply(adapted.init(rng, eqn, inputs), eqn, inputs)
    # Results are the same since the LoRA parameters are initialized to zeros.
    assert jnp.allclose(base_out, adapted_out)
def test_lora_ffn_params_shape():
    """Base FeedForward shapes are unchanged by LoRA; adapters get rank-factorized shapes."""
    base = lora.FeedForward(features=8, hidden_dim=32)
    adapted = lora.FeedForward(features=8, hidden_dim=32, lora_config=lora.LoRAConfig(rank=2))
    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (2, 8))
    base_params = base.init(rng, inputs)["params"]
    assert base_params["gating_einsum"].shape == (2, 8, 32)
    assert base_params["linear"].shape == (32, 8)
    adapted_params = adapted.init(rng, inputs)["params"]
    # The base parameters keep their shapes...
    assert adapted_params["gating_einsum"].shape == (2, 8, 32)
    assert adapted_params["linear"].shape == (32, 8)
    # ...and the adapters factor through rank 2.
    assert adapted_params["gating_einsum_lora_a"].shape == (2, 8, 2)
    assert adapted_params["gating_einsum_lora_b"].shape == (2, 2, 32)
    assert adapted_params["linear_lora_a"].shape == (32, 2)
    assert adapted_params["linear_lora_b"].shape == (2, 8)
def test_lora_ffn_same_output():
    """Zero-initialized LoRA adapters must leave the FeedForward output unchanged."""
    base = lora.FeedForward(features=8, hidden_dim=32)
    adapted = lora.FeedForward(
        features=8,
        hidden_dim=32,
        lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
    )
    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (2, 8))
    base_out = base.apply(base.init(rng, inputs), inputs)
    adapted_out = adapted.apply(adapted.init(rng, inputs), inputs)
    assert jnp.allclose(base_out, adapted_out)

View File

@@ -0,0 +1,332 @@
import abc
from collections.abc import Sequence
import dataclasses
import enum
import logging
import pathlib
from typing import Generic, TypeVar
import augmax
from flax import nnx
from flax import struct
from flax import traverse_util
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import torch
from openpi.models_pytorch import pi0_pytorch
from openpi.shared import image_tools
import openpi.shared.array_typing as at
# Module-level logger shared across openpi model code.
logger = logging.getLogger("openpi")

# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
class ModelType(enum.Enum):
    """Supported model types."""

    # Flow-matching action model with continuous state input.
    PI0 = "pi0"
    # Autoregressive model over FAST action tokens.
    PI0_FAST = "pi0_fast"
    # Pi0.5 variant: discrete state input and adaRMS timestep conditioning.
    PI05 = "pi05"
# The model always expects these images
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)

# This may need change if we release a small model.
# Target (height, width) that all camera images are resized/padded to.
IMAGE_RESOLUTION = (224, 224)
# Data format
#
# Data transforms produce the model input as a nested dictionary which is later converted
# into `Observation` and `Actions` objects. See below.
#
# In the dictionary form, this data should look like:
# {
# # Observation data.
# "image": {
# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
# ... # Additional camera views
# },
# "image_mask": {
# "base_0_rgb": bool[*b], # True if image is valid
# ... # Masks for additional views
# },
# "state": float32[*b, s], # Low-dimensional robot state
# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
#
# # Actions data.
# "actions": float32[*b ah ad]
# }
# where:
# *b = batch dimensions
# h,w = image height/width
# s = state dimension
# l = sequence length
#
@at.typecheck
@struct.dataclass
class Observation(Generic[ArrayT]):
    """Holds observations, i.e., inputs to the model.

    See `Observation.from_dict` to see the expected dictionary form. This is the format
    that should be produced by the data transforms.
    """

    # Images, in [-1, 1] float32.
    images: dict[str, at.Float[ArrayT, "*b h w c"]]
    # Image masks, with same keys as images.
    image_masks: dict[str, at.Bool[ArrayT, "*b"]]
    # Low-dimensional robot state.
    state: at.Float[ArrayT, "*b s"]
    # Tokenized prompt.
    tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
    # Tokenized prompt mask.
    tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None
    # pi0-fast model specific fields.
    # Token auto-regressive mask (for FAST autoregressive model).
    token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
    # Token loss mask (for FAST autoregressive model).
    token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None

    @classmethod
    def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
        """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format.

        NOTE: mutates `data["image"]` in place when converting uint8 images.
        """
        # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
        if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
            raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
        # If images are uint8, convert them to [-1, 1] float32.
        for key in data["image"]:
            if data["image"][key].dtype == np.uint8:
                data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
            elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
                # NOTE(review): unlike the numpy branch, the torch branch also permutes
                # to channels-first (B, C, H, W) — presumably for the PyTorch model path;
                # verify against callers.
                data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
        return cls(
            images=data["image"],
            image_masks=data["image_mask"],
            state=data["state"],
            tokenized_prompt=data.get("tokenized_prompt"),
            tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
            token_ar_mask=data.get("token_ar_mask"),
            token_loss_mask=data.get("token_loss_mask"),
        )

    def to_dict(self) -> at.PyTree[ArrayT]:
        """Convert the Observation to a nested dict."""
        result = dataclasses.asdict(self)
        # Rename fields back to the external dict keys used by the data transforms.
        result["image"] = result.pop("images")
        result["image_mask"] = result.pop("image_masks")
        return result
# Defines the format of the actions. This field is included as "actions" inside the dictionary
# produced by the data transforms.
# Axes: *b = batch dimensions, ah = action horizon, ad = action dimension.
Actions = at.Float[ArrayT, "*b ah ad"]
def preprocess_observation(
    rng: at.KeyArrayLike | None,
    observation: Observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> Observation:
    """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
    filling in a default image mask (if necessary).

    Args:
        rng: PRNG key used for train-time augmentations; only needed when train=True.
        observation: The observation to preprocess; must contain all `image_keys`.
        train: If True, apply per-example random crop/rotate/color-jitter augmentations.
        image_keys: Camera views the model expects.
        image_resolution: Target (height, width) for every image.

    Returns:
        A new Observation with processed images/masks; all other fields pass through unchanged.
    """
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
    batch_shape = observation.state.shape[:-1]
    out_images = {}
    for key in image_keys:
        image = observation.images[key]
        if image.shape[1:3] != image_resolution:
            logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
            image = image_tools.resize_with_pad(image, *image_resolution)
        if train:
            # Convert from [-1, 1] to [0, 1] for augmax.
            image = image / 2.0 + 0.5
            transforms = []
            if "wrist" not in key:
                # Geometric augmentations (crop to 95%, resize back, small rotation)
                # are applied only to non-wrist cameras.
                height, width = image.shape[1:3]
                transforms += [
                    augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
                    augmax.Resize(width, height),
                    augmax.Rotate((-5, 5)),
                ]
            transforms += [
                augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
            ]
            # One independent augmentation chain per batch element.
            sub_rngs = jax.random.split(rng, image.shape[0])
            image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
            # Back to [-1, 1].
            image = image * 2.0 - 1.0
        out_images[key] = image
    # obtain mask
    out_masks = {}
    for key in out_images:
        if key not in observation.image_masks:
            # do not mask by default
            out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
        else:
            out_masks[key] = jnp.asarray(observation.image_masks[key])
    return Observation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
@dataclasses.dataclass(frozen=True)
class BaseModelConfig(abc.ABC):
    """Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
    method to create the corresponding model.
    """

    # Action space dimension.
    action_dim: int
    # Action sequence length.
    action_horizon: int
    # Tokenized prompt maximum length.
    max_token_len: int

    @property
    @abc.abstractmethod
    def model_type(self) -> ModelType:
        """The model type."""

    @abc.abstractmethod
    def create(self, rng: at.KeyArrayLike) -> "BaseModel":
        """Create a new model, initializing parameters."""

    def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
        """Create a model with the given parameters.

        Args:
            params: Pure-dict parameter PyTree to load into the model.
            remove_extra_params: If True, drop parameters not defined by the model.
        """
        # Build the model structure without allocating real parameter arrays.
        model = nnx.eval_shape(self.create, jax.random.key(0))
        graphdef, state = nnx.split(model)
        if remove_extra_params:
            params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
        # Shapes must match the model definition; dtypes may differ (e.g. bf16 checkpoints).
        at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
        state.replace_by_pure_dict(params)
        return nnx.merge(graphdef, state)

    def load_pytorch(self, train_config, weight_path: str):
        """Create the PyTorch model variant and load safetensors weights from `weight_path`."""
        logger.info(f"train_config: {train_config}")
        model = pi0_pytorch.PI0Pytorch(config=train_config.model)
        safetensors.torch.load_model(model, weight_path)
        return model

    @abc.abstractmethod
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
        """Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""

    def fake_obs(self, batch_size: int = 1) -> Observation:
        """Returns an all-ones Observation matching `inputs_spec` (useful for init and tests)."""
        observation_spec, _ = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)

    def fake_act(self, batch_size: int = 1) -> Actions:
        """Returns an all-ones Actions array matching `inputs_spec`."""
        _, action_spec = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
@dataclasses.dataclass
class BaseModel(nnx.Module, abc.ABC):
    """Base class for all model implementations. Specific models should inherit from this class. They should call
    super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
    """

    # Action space dimension.
    action_dim: int
    # Action sequence length.
    action_horizon: int
    # Tokenized prompt maximum length.
    max_token_len: int

    @abc.abstractmethod
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: Observation,
        actions: Actions,
        *,
        train: bool = False,
    ) -> at.Float[at.Array, "*b ah"]:
        """Returns the training loss for a batch (see the return annotation for the shape)."""
        ...

    @abc.abstractmethod
    def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions:
        """Samples an action sequence conditioned on the observation."""
        ...
def restore_params(
    params_path: pathlib.Path | str,
    *,
    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
    dtype: jnp.dtype | None = None,
    sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
    """Restores unstructured params PyTree from a checkpoint.

    This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
    well as pre-trained checkpoints released for openpi.

    Args:
        params_path: The local path to the checkpoint directory.
        restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
        dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
        sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.

    Returns:
        The restored params.
    """
    # Leave gs:// URLs untouched; resolve local paths to absolute for orbax.
    params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path
    if restore_type is jax.Array and sharding is None:
        # Default: fully replicate the params across all available devices.
        mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    with ocp.PyTreeCheckpointer() as ckptr:
        # Restore only the "params" subtree of the checkpoint.
        metadata = ckptr.metadata(params_path)
        item = {"params": metadata["params"]}
        params = ckptr.restore(
            params_path,
            ocp.args.PyTreeRestore(
                item=item,
                restore_args=jax.tree.map(
                    lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
                ),
            ),
        )["params"]
    # If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
    # added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
    flat_params = traverse_util.flatten_dict(params)
    if all(kp[-1] == "value" for kp in flat_params):
        flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
    return traverse_util.unflatten_dict(flat_params)

View File

@@ -0,0 +1,94 @@
from flax import nnx
import jax
import pytest
from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.models import pi0_fast
from openpi.shared import download
from openpi.shared import nnx_utils
def test_pi0_model():
    """Smoke test: Pi0 loss and sampled actions have the expected shapes."""
    rng = jax.random.key(0)
    config = pi0_config.Pi0Config()
    model = config.create(rng)
    bs = 2
    obs = config.fake_obs(bs)
    act = config.fake_act(bs)
    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (bs, config.action_horizon)
    actions = nnx_utils.module_jit(model.sample_actions)(rng, obs, num_steps=10)
    assert actions.shape == (bs, model.action_horizon, model.action_dim)
def test_pi0_lora_model():
    """Smoke test: Pi0 with a LoRA PaliGemma variant keeps the same output shapes."""
    rng = jax.random.key(0)
    config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
    model = config.create(rng)
    bs = 2
    obs = config.fake_obs(bs)
    act = config.fake_act(bs)
    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (bs, config.action_horizon)
    actions = nnx_utils.module_jit(model.sample_actions)(rng, obs, num_steps=10)
    assert actions.shape == (bs, model.action_horizon, model.action_dim)
def test_pi0_fast_model():
    """Smoke test: Pi0-FAST loss and sampled token shapes."""
    rng = jax.random.key(0)
    config = pi0_fast.Pi0FASTConfig()
    model = config.create(rng)
    bs = 2
    obs = config.fake_obs(bs)
    act = config.fake_act(bs)
    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (bs,)
    actions = nnx_utils.module_jit(model.sample_actions)(rng, obs)
    assert actions.shape == (bs, 256)
def test_pi0_fast_lora_model():
    """Pi0-FAST with LoRA: shapes unchanged and LoRA params present in the state."""
    rng = jax.random.key(0)
    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
    model = config.create(rng)
    bs = 2
    obs = config.fake_obs(bs)
    act = config.fake_act(bs)
    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (bs,)
    actions = nnx_utils.module_jit(model.sample_actions)(rng, obs)
    assert actions.shape == (bs, 256)
    # The model state must contain at least one LoRA parameter.
    lora_entries = list(nnx.state(model).filter(nnx_utils.PathRegex(".*lora.*")))
    assert len(lora_entries) > 0
@pytest.mark.manual
def test_model_restore():
    """Restores the released pi0_base checkpoint and runs a forward pass (manual: downloads weights)."""
    rng = jax.random.key(0)
    config = pi0_config.Pi0Config()
    bs = 2
    obs = config.fake_obs(bs)
    act = config.fake_act(bs)
    params_path = download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params")
    model = config.load(_model.restore_params(params_path))
    loss = model.compute_loss(rng, obs, act)
    assert loss.shape == (bs, config.action_horizon)
    actions = model.sample_actions(rng, obs, num_steps=10)
    assert actions.shape == (bs, model.action_horizon, model.action_dim)

View File

@@ -0,0 +1,279 @@
import logging
import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override
from openpi.models import model as _model
from openpi.models import pi0_config
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
logger = logging.getLogger("openpi")
def make_attn_mask(input_mask, mask_ar):
    """Adapted from big_vision.

    Builds a boolean attention mask from a padding mask and a "block start"
    mask. A token may attend to every valid token whose cumulative `mask_ar`
    count is not larger than its own, which supports e.g.:

      [[1 1 1 1 1 1]]: pure causal attention.
      [[0 0 0 1 1 1]]: prefix-lm attention (bidirectional prefix, causal tail).
      [[1 0 1 0 1 0 0 1 0 0]]: block-causal attention between 4 blocks.

    Args:
      input_mask: bool[B, N] true if its part of the input, false if padding.
      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
        it and false where it shares the same attention mask as the previous token.

    Returns:
      bool[B, N, N] attention mask (query axis second, key axis last).
    """
    broadcast_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
    # Running count of block starts assigns each token a block id.
    block_id = jnp.cumsum(broadcast_ar, axis=1)
    # A query may attend to a key iff the key's block does not come after its own.
    can_attend = block_id[:, None, :] <= block_id[:, :, None]
    # Both the query and the key position must be real (non-padding) tokens.
    both_valid = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(can_attend, both_valid)
@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if embedding_dim % 2 != 0:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
    # Geometrically spaced periods between min_period and max_period,
    # one per sin/cos pair.
    frac = jnp.linspace(0.0, 1.0, embedding_dim // 2)
    periods = min_period * (max_period / min_period) ** frac
    angular_freq = 1.0 / periods * 2 * jnp.pi
    # Outer product pos x frequency, in highest precision.
    phase = jnp.einsum("i,j->ij", pos, angular_freq, precision=jax.lax.Precision.HIGHEST)
    return jnp.concatenate([jnp.sin(phase), jnp.cos(phase)], axis=-1)
class Pi0(_model.BaseModel):
    """Pi0 flow-matching action model.

    A PaliGemma vision-language prefix (expert 0) is combined with a Gemma
    "action expert" suffix (expert 1) that denoises an action chunk via flow
    matching. When `config.pi05` is set, the flow-matching timestep is injected
    through adaRMS conditioning instead of being mixed into the action tokens.
    """

    def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        self.pi05 = config.pi05
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)
        # TODO: rewrite gemma in NNX. For now, use bridge.
        # Two-expert Gemma module: index 0 is the PaliGemma LLM, index 1 the action expert.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                configs=[paligemma_config, action_expert_config],
                embed_dtype=config.dtype,
                adarms=config.pi05,
            )
        )
        # Only the action expert uses adaRMS conditioning, and only in pi05 mode.
        llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False])
        # SigLIP vision encoder producing LLM-width image tokens (no pooling).
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)
        # Projects (noisy) actions into the action expert's embedding space.
        self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        if config.pi05:
            # pi05: timestep MLP feeding the adaRMS conditioning signal.
            self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
            self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        else:
            # pi0: continuous state token plus an MLP that mixes actions with the timestep.
            self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
            self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
            self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        # Maps action-expert outputs back to action space (the predicted velocity).
        self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
        # This attribute gets automatically set by model.train() and model.eval().
        self.deterministic = True

    @at.typecheck
    def embed_prefix(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
        """Embeds the prefix (image + language) tokens.

        Returns (tokens, input_mask, ar_mask); ar_mask is all-False, so the
        prefix attends bidirectionally (see `make_attn_mask`).
        """
        input_mask = []
        ar_mask = []
        tokens = []
        # embed images
        for name in obs.images:
            image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
            tokens.append(image_tokens)
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_tokens.shape[1],
                )
            )
            # image tokens attend to each other
            ar_mask += [False] * image_tokens.shape[1]
        # add language (aka tokenized inputs)
        if obs.tokenized_prompt is not None:
            tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
            tokens.append(tokenized_inputs)
            input_mask.append(obs.tokenized_prompt_mask)
            # full attention between image and language inputs
            ar_mask += [False] * tokenized_inputs.shape[1]
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask

    @at.typecheck
    def embed_suffix(
        self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
    ) -> tuple[
        at.Float[at.Array, "b s emb"],
        at.Bool[at.Array, "b s"],
        at.Bool[at.Array, " s"],
        at.Float[at.Array, "b emb"] | None,
    ]:
        """Embeds the suffix (optional state token + noisy action tokens).

        Returns (tokens, input_mask, ar_mask, adarms_cond); adarms_cond is the
        timestep conditioning vector in pi05 mode, otherwise None.
        """
        input_mask = []
        ar_mask = []
        tokens = []
        if not self.pi05:
            # add a single state token
            state_token = self.state_proj(obs.state)[:, None, :]
            tokens.append(state_token)
            input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
            # image/language inputs do not attend to state or actions
            ar_mask += [True]
        action_tokens = self.action_in_proj(noisy_actions)
        # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
        if self.pi05:
            # time MLP (for adaRMS)
            time_emb = self.time_mlp_in(time_emb)
            time_emb = nnx.swish(time_emb)
            time_emb = self.time_mlp_out(time_emb)
            time_emb = nnx.swish(time_emb)
            action_expert_tokens = action_tokens
            adarms_cond = time_emb
        else:
            # mix timestep + action information using an MLP (no adaRMS)
            time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
            action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
            action_time_tokens = self.action_time_mlp_in(action_time_tokens)
            action_time_tokens = nnx.swish(action_time_tokens)
            action_time_tokens = self.action_time_mlp_out(action_time_tokens)
            action_expert_tokens = action_time_tokens
            adarms_cond = None
        tokens.append(action_expert_tokens)
        input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))
        # image/language/state inputs do not attend to action tokens
        ar_mask += [True] + ([False] * (self.action_horizon - 1))
        tokens = jnp.concatenate(tokens, axis=1)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = jnp.array(ar_mask)
        return tokens, input_mask, ar_mask, adarms_cond

    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> at.Float[at.Array, "*b ah"]:
        """Flow-matching loss: per-step MSE between predicted and target velocity."""
        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
        observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
        batch_shape = actions.shape[:-2]
        noise = jax.random.normal(noise_rng, actions.shape)
        # Sample t in [0.001, 1.0]; Beta(1.5, 1) skews sampling toward t=1 (the noise end).
        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
        time_expanded = time[..., None, None]
        # Linear interpolation between clean actions (t=0) and noise (t=1).
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        # Target velocity of the flow.
        u_t = noise - actions
        # one big forward pass of prefix + suffix at once
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
        input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
        ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
        attn_mask = make_attn_mask(input_mask, ar_mask)
        # Positions count only valid tokens (padding does not advance the position).
        positions = jnp.cumsum(input_mask, axis=1) - 1
        (prefix_out, suffix_out), _ = self.PaliGemma.llm(
            [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
        )
        # The last `action_horizon` suffix outputs correspond to the action tokens.
        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
        return jnp.mean(jnp.square(v_t - u_t), axis=-1)

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
    ) -> _model.Actions:
        """Samples an action chunk by Euler-integrating the learned flow from noise (t=1) to data (t=0).

        Args:
            rng: PRNG key used to draw the initial noise (unused if `noise` is given).
            observation: The (unprocessed) observation to condition on.
            num_steps: Number of Euler integration steps.
            noise: Optional initial noise; standard normal is drawn if None.
        """
        observation = _model.preprocess_observation(None, observation, train=False)
        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
        # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        if noise is None:
            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
        # first fill KV cache with a forward pass of the prefix
        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        positions = jnp.cumsum(prefix_mask, axis=1) - 1
        _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

        def step(carry):
            # One Euler step: predict velocity at (x_t, t) and advance x_t by dt.
            x_t, time = carry
            suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
                observation, x_t, jnp.broadcast_to(time, batch_size)
            )
            # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
            # other
            suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
            # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
            # prefix tokens
            prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
            # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
            # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
            full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
            assert full_attn_mask.shape == (
                batch_size,
                suffix_tokens.shape[1],
                prefix_tokens.shape[1] + suffix_tokens.shape[1],
            )
            # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
            positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
            (prefix_out, suffix_out), _ = self.PaliGemma.llm(
                [None, suffix_tokens],
                mask=full_attn_mask,
                positions=positions,
                kv_cache=kv_cache,
                adarms_cond=[None, adarms_cond],
            )
            assert prefix_out is None
            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
            return x_t + dt * v_t, time + dt

        def cond(carry):
            x_t, time = carry
            # robust to floating-point error
            return time >= -dt / 2

        x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
        return x_0

View File

@@ -0,0 +1,108 @@
import dataclasses
from typing import TYPE_CHECKING
import flax.nnx as nnx
import jax
import jax.numpy as jnp
from typing_extensions import override
from openpi.models import model as _model
import openpi.models.gemma as _gemma
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils
if TYPE_CHECKING:
from openpi.models.pi0 import Pi0
@dataclasses.dataclass(frozen=True)
class Pi0Config(_model.BaseModelConfig):
    """Configuration for the Pi0 (and Pi0.5) flow-matching model."""

    # Computation dtype for model weights/activations.
    dtype: str = "bfloat16"
    # Gemma variant for the PaliGemma vision-language backbone.
    paligemma_variant: _gemma.Variant = "gemma_2b"
    # Gemma variant for the action expert.
    action_expert_variant: _gemma.Variant = "gemma_300m"
    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 50
    # Defaulted in __post_init__: 200 for pi05, 48 otherwise.
    max_token_len: int = None  # type: ignore
    # Pi05 has two differences from Pi0:
    # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix
    # - the action expert uses adaRMSNorm to inject the flow matching timestep
    pi05: bool = False
    # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
    # Defaulted in __post_init__ to match `pi05`.
    discrete_state_input: bool = None  # type: ignore

    def __post_init__(self):
        # Fill in pi05-dependent defaults; object.__setattr__ is required
        # because the dataclass is frozen.
        if self.max_token_len is None:
            object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
        if self.discrete_state_input is None:
            object.__setattr__(self, "discrete_state_input", self.pi05)

    @property
    @override
    def model_type(self) -> _model.ModelType:
        if self.pi05:
            return _model.ModelType.PI05
        return _model.ModelType.PI0

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0":
        # Imported lazily to avoid a circular import with openpi.models.pi0.
        from openpi.models.pi0 import Pi0

        return Pi0(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Returns jax.ShapeDtypeStruct specs for the observation and action inputs."""
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "left_wrist_0_rgb": image_spec,
                    "right_wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "left_wrist_0_rgb": image_mask_spec,
                    "right_wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config."""
        filters = []
        has_lora = False
        # ".*llm.*" matches all Gemma params (both experts); the action expert's
        # params additionally match ".*llm.*_1.*".
        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
        if "lora" in self.paligemma_variant:
            filters.append(
                gemma_params_filter,
            )
            if "lora" not in self.action_expert_variant:
                # If only freeze gemma params, exclude action expert params.
                filters.append(
                    nnx.Not(action_expert_params_filter),
                )
            has_lora = True
        elif "lora" in self.action_expert_variant:
            filters.append(
                action_expert_params_filter,
            )
            has_lora = True
        if has_lora:
            # If any lora is used, exclude all lora params.
            filters.append(
                nnx.Not(nnx_utils.PathRegex(".*lora.*")),
            )
        if not filters:
            return nnx.Nothing
        return nnx.All(*filters)

View File

@@ -0,0 +1,313 @@
import dataclasses
import logging
from typing import Any
import einops
import flax.nnx as nnx
import flax.nnx.bridge as nnx_bridge
import jax
import jax.numpy as jnp
from typing_extensions import override
from openpi.models import model as _model
import openpi.models.gemma_fast as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils
logger = logging.getLogger("openpi")

# Token id of PaliGemma's end-of-sequence token.
PALIGEMMA_EOS_TOKEN = 1
def make_attn_mask(input_mask, mask_ar):
    """Adapted from big_vision.

    Builds a boolean attention mask from a padding mask and a "block start"
    mask. A token may attend to every valid token whose cumulative `mask_ar`
    count is not larger than its own, which supports e.g.:

      [[1 1 1 1 1 1]]: pure causal attention.
      [[0 0 0 1 1 1]]: prefix-lm attention (bidirectional prefix, causal tail).
      [[1 0 1 0 1 0 0 1 0 0]]: block-causal attention between 4 blocks.

    Args:
      input_mask: bool[B, N] true if its part of the input, false if padding.
      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
        it and false where it shares the same attention mask as the previous token.

    Returns:
      bool[B, N, N] attention mask (query axis second, key axis last).
    """
    broadcast_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
    # Running count of block starts assigns each token a block id.
    block_id = jnp.cumsum(broadcast_ar, axis=1)
    # A query may attend to a key iff the key's block does not come after its own.
    can_attend = block_id[:, None, :] <= block_id[:, :, None]
    # Both the query and the key position must be real (non-padding) tokens.
    both_valid = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(can_attend, both_valid)
@jax.vmap
def left_to_right_align(x, input_mask, attn_mask):
    """Converts input from left-aligned to right-aligned.

    Valid (non-padding) tokens are assumed to occupy a contiguous prefix;
    after the roll they occupy the suffix instead, with padding in front.
    """
    # Due to vmap, this is operating in a single example (not batch level).
    assert x.ndim == 2
    assert input_mask.ndim == 1
    assert attn_mask.ndim == 2
    assert x.shape[0] == input_mask.shape[0]
    assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
    # Number of valid tokens: highest set mask index + 1.
    num_valid = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1
    # Rolling left by num_valid wraps the valid prefix around to the end.
    rolled_x = jnp.roll(x, -num_valid, axis=0)
    rolled_mask = jnp.roll(input_mask, -num_valid, axis=0)
    rolled_attn = jnp.roll(attn_mask, -num_valid, axis=(0, 1))
    return rolled_x, rolled_mask, rolled_attn
def put_along_last_axis(arr, indices, values):
    """Like np.put_along_axis(..., axis=-1), since jax is missing it."""
    assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim)
    selector = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
    # Scatter the new values into their slots; `touched` marks written positions
    # so untouched entries keep their original value.
    scattered = jnp.einsum("...i,...in->...n", values, selector)
    touched = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), selector)
    return jnp.where(touched, scattered, arr)
@dataclasses.dataclass(frozen=True)
class Pi0FASTConfig(_model.BaseModelConfig):
    """Configuration for the Pi0-FAST autoregressive action model.

    Immutable (frozen) so instances can be shared/hashed safely by the
    training configuration machinery.
    """

    # Compute dtype used for the LLM embeddings/cache and the vision tower.
    dtype: str = "bfloat16"
    # Which Gemma backbone to instantiate (see gemma_fast.get_config).
    paligemma_variant: _gemma.Variant = "gemma_2b"
    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 32
    max_token_len: int = 250
    # Tokenizer for the fast model.
    fast_model_tokenizer: Any | None = None
    # Keyword arguments for the fast model tokenizer.
    fast_model_tokenizer_kwargs: dict[str, Any] | None = None
    @property
    @override
    def model_type(self) -> _model.ModelType:
        """Identifies this config as a PI0_FAST model."""
        return _model.ModelType.PI0_FAST
    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0FAST":
        """Instantiates the model with parameters initialized from `rng`."""
        return Pi0FAST(self, rngs=nnx.Rngs(rng))
    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Returns shape/dtype specs (no real data) for one training batch.

        Used e.g. with nnx.eval_shape to build the model abstractly.
        """
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
        # Observation is normally type-checked against concrete arrays; disable
        # the checks so it can carry ShapeDtypeStruct placeholders instead.
        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "base_1_rgb": image_spec,
                    "wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "base_1_rgb": image_mask_spec,
                    "wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
                token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
        return observation_spec, action_spec
    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config."""
        # For LoRA variants, freeze every LLM parameter except the LoRA
        # adapters themselves; otherwise nothing is frozen (full finetune).
        if "lora" in self.paligemma_variant:
            return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
        return nnx.Nothing
class Pi0FAST(_model.BaseModel):
    """PaliGemma-based autoregressive policy: SigLIP vision tower + Gemma LLM.

    Actions are represented as discrete tokens appended to the prompt; training
    is next-token prediction and inference is autoregressive decoding.
    """

    def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        # TODO: rewrite gemma in NNX. For now, use bridge.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                **paligemma_config,
                embed_dtype=config.dtype,
                cache_dtype=config.dtype,
            )
        )
        llm.lazy_init(rngs=rngs, method="init")
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        # Initialize the vision tower with a fake observation's image so the
        # bridge can infer parameter shapes.
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)
    @at.typecheck
    def embed_inputs(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]:
        """Embeds images + prompt tokens into one [b, s, emb] sequence.

        Returns (embeddings, validity mask, auto-regressive mask); image tokens
        come first, followed by the tokenized prompt.
        """
        input_mask = []
        ar_mask = []
        token_embeddings = []
        # embed images
        for name in obs.images:
            image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False)
            token_embeddings.append(image_token_embeddings)
            # Broadcast the per-image validity bit to all of its patch tokens.
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_token_embeddings.shape[1],
                )
            )
            # image tokens attend to each other --> AR mask = 0
            ar_mask.append(0 * input_mask[-1])
        # add tokenized inputs
        assert obs.tokenized_prompt is not None, "Tokenized prompt is required"
        assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required"
        assert obs.token_ar_mask is not None, "Token auto-regressive mask is required"
        tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True)
        token_embeddings.append(tokenized_inputs_embeddings)
        input_mask.append(obs.tokenized_prompt_mask)
        ar_mask.append(obs.token_ar_mask)
        # return embeddings, input mask, and ar mask
        return (
            jnp.concatenate(token_embeddings, axis=1),
            jnp.concatenate(input_mask, axis=1),
            jnp.concatenate(ar_mask, axis=1),
        )
    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> at.Float[at.Array, "*b ah"]:
        """Per-example masked next-token cross-entropy over the prompt tokens.

        Note: `actions` is unused here — the action targets are already encoded
        inside `observation.tokenized_prompt` / `token_loss_mask`.
        """
        observation = _model.preprocess_observation(
            rng, observation, train=train, image_keys=list(observation.images.keys())
        )
        # Compute inputs: one big forward pass of prefix + suffix at once
        input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
        attn_mask = make_attn_mask(input_mask, ar_mask)
        # Compute one-hot targets: we predict *next* token, so shift the input tokens by one.
        targets = jax.nn.one_hot(
            observation.tokenized_prompt[:, 1:],
            self.PaliGemma.llm.module.vocab_size,
        )
        # Each input predicts *next* token, so we don't input the last token.
        pre_logits, _, _ = self.PaliGemma.llm(
            embedded_prefix=input_token_embeddings[:, :-1],
            mask=attn_mask[:, :-1, :-1],
            return_prelogits=True,
        )
        # Only decode logits for the target tokens to save memory
        # (decoding matmul is large because it is a seq_len x vocab_size dense layer).
        # The last targets.shape[1] positions are exactly the ones predicting
        # the (shifted) prompt tokens; image positions are dropped.
        logits, _ = self.PaliGemma.llm(
            pre_logits=pre_logits[:, -targets.shape[1] :],
        )
        logp = jax.nn.log_softmax(logits, axis=-1)
        # Compute CE loss on token targets
        assert observation.token_loss_mask is not None, "Token loss mask is required"
        loss_mask = observation.token_loss_mask[:, 1:]
        token_pplx = jnp.sum(targets * logp, axis=-1)
        # Mean over unmasked target tokens; clip avoids division by zero when a
        # row has no loss tokens.
        return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)
    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        max_decoding_steps: int | at.Int[at.Array, ""] = 256,
        temperature: float = 0.0,
    ) -> _model.Actions:
        """Autoregressively decodes up to `max_decoding_steps` tokens.

        Returns the raw decoded token ids, shape [b, max_decoding_steps],
        stored in a float array (jnp.zeros default dtype); callers are expected
        to turn them back into actions via the tokenizer.
        """
        # TODO: this is a hack to get the image keys.
        observation = _model.preprocess_observation(
            None, observation, train=False, image_keys=list(observation.images.keys())
        )
        # embed inputs
        prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
        # left to right align all input token sequences
        prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
            prefix_token_embeddings, prefix_mask, prefix_attn_mask
        )
        prefill_size = prefix_token_embeddings.shape[1]
        prefill_len = jnp.sum(prefix_mask, axis=-1)
        # Per-example index of the first valid (right-aligned) prefix slot.
        prefix_start = prefill_size - prefill_len
        # first fill KV cache with a forward pass of the prefix
        # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
        prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
        prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
        prefix_logits, kv_cache, _ = self.PaliGemma.llm(
            embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
        )
        # prepare decoding -- final logit decodes the first token
        last_logit = prefix_logits[:, -1:]
        output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))
        def step(carry):
            rng, last_logit, output_tokens, cache, _, step = carry
            # Sample token from last logit
            # Split RNG for this step
            rng, rng_step = jax.random.split(rng)
            # temperature > 0: stochastic sampling; otherwise greedy argmax.
            token = jax.lax.cond(
                temperature > 0.0,
                lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
                lambda _: jnp.argmax(last_logit, axis=-1),
                operand=None,
            )
            output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)
            # Check for early stopping --> stop if all batch elements have EOS token
            # NOTE(review): this only inspects the token emitted at *this* step,
            # so decoding stops early only if every batch element emits EOS in
            # the same step; elements that already hit EOS earlier keep decoding.
            has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
            all_eos = jnp.all(has_eos)
            # Decode one step
            token_embedding = self.PaliGemma.llm(token, embed_only=True)
            # NOTE(review): prefix positions are 0-based (cumsum-1 ends at
            # prefill_len-1), so `+ step + 1` assigns the first generated token
            # position prefill_len+1, skipping prefill_len — looks like a
            # possible off-by-one; confirm against gemma_fast's RoPE handling.
            positions = prefill_len[:, None] + step + 1
            # Attend to valid cache slots only: from the first real prefix slot
            # up to (and including) the slot being written this step.
            mask = jnp.logical_and(
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
                < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
            )
            last_logit, kv_cache, _ = self.PaliGemma.llm(
                embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
            )
            return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1
        def cond(carry):
            _, _, _, _, all_eos, step = carry
            return (~all_eos) & (step < max_decoding_steps)
        # Use lax.while_loop so we can jit the full decoding loop.
        _, _, output_tokens, _, _, _ = jax.lax.while_loop(
            cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
        )
        return output_tokens

View File

@@ -0,0 +1,46 @@
import flax.nnx as nnx
import jax
import openpi.models.pi0_config as _pi0_config
def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:
    """Builds the model abstractly and returns its frozen-parameter state."""
    shape_only_model = nnx.eval_shape(config.create, jax.random.key(0))
    frozen_filter = config.get_freeze_filter()
    return nnx.state(shape_only_model, nnx.All(nnx.Param, frozen_filter)).flat_state()
def test_pi0_full_finetune():
    """Default config (no LoRA) should freeze nothing."""
    frozen = _get_frozen_state(_pi0_config.Pi0Config())
    assert len(frozen) == 0
def test_pi0_gemma_lora():
    """LoRA on PaliGemma freezes only the non-LoRA llm parameters."""
    frozen = _get_frozen_state(_pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora"))
    assert len(frozen) == 9
    assert all("lora" not in key for key in frozen)
    assert all("llm" in key for key in frozen)
    # None of the frozen params belong to the action expert ("_1" suffix).
    assert all("_1" not in key for key in frozen)
def test_pi0_action_expert_lora():
    """LoRA on the action expert freezes only its non-LoRA llm parameters."""
    frozen = _get_frozen_state(_pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora"))
    # excluding embedder, rest of the params should be same as gemma_lora.
    assert len(frozen) == 8
    assert all("lora" not in key for key in frozen)
    assert all("llm" in key for key in frozen)
    # all frozen params should have _1 in their path since it's the action expert.
    assert all(any("_1" in part for part in key) for key in frozen)
def test_pi0_all_lora():
    """LoRA on both submodels freezes the union of the two frozen sets."""
    frozen = _get_frozen_state(
        _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora")
    )
    # sum of gemma_lora and action_expert_lora's frozen params.
    assert len(frozen) == 17
    assert all("lora" not in key for key in frozen)
    assert all("llm" in key for key in frozen)

View File

@@ -0,0 +1,373 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A refactored and simplified ViT adoptation for Pi, taken from big_vision."""
from collections.abc import Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import openpi.training.sharding as sharding
def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
    """2D sin-cos position embeddings, following the MoCo v3 logic.

    Returns an array of shape [1, h*w, width]; the leading axis broadcasts
    over the batch.
    """
    assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
    grid_y, grid_x = jnp.mgrid[:h, :w]
    # A quarter of the channels goes to each of sin/cos over x and y.
    fraction = jnp.arange(width // 4) / (width // 4 - 1)
    inv_freq = 1.0 / (temperature**fraction)
    ang_y = jnp.einsum("m,d->md", grid_y.flatten(), inv_freq)
    ang_x = jnp.einsum("m,d->md", grid_x.flatten(), inv_freq)
    pe = jnp.concatenate([jnp.sin(ang_x), jnp.cos(ang_x), jnp.sin(ang_y), jnp.cos(ang_y)], axis=1)
    return jnp.asarray(pe, dtype)[None, :, :]
def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
    """Returns position embeddings of shape [1, prod(seqshape), width].

    "learn" creates a trainable parameter on `self` (a linen module);
    "sincos2d" returns fixed sin-cos embeddings.
    """
    if typ == "learn":
        init = nn.initializers.normal(stddev=1 / np.sqrt(width))
        return self.param(name, init, (1, np.prod(seqshape), width), dtype)
    if typ == "sincos2d":
        return posemb_sincos_2d(*seqshape, width, dtype=dtype)
    raise ValueError(f"Unknown posemb type: {typ}")
class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""
    mlp_dim: int | None = None  # Defaults to 4x input dim
    dropout: float = 0.0
    dtype_mm: str = "float32"  # compute dtype for the Dense layers
    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        """Applies Transformer MlpBlock module."""
        inits = {
            "kernel_init": nn.initializers.xavier_uniform(),
            "bias_init": nn.initializers.normal(stddev=1e-6),
        }
        _, _, d = x.shape  # n,l,d
        # Expand -> GELU -> dropout -> project back to the input dim d.
        x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic)
        return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
class Encoder1DBlock(nn.Module):
    """Single transformer encoder block (MHSA + MLP)."""
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    dtype_mm: str = "float32"  # compute dtype for attention/MLP layers
    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        """Pre-LN block: x + SA(LN(x)), then x + MLP(LN(x)).

        Returns (x, out) where `out` holds the intermediate activations.
        """
        out = {}
        # Sharding constraints keep activations partitioned consistently under
        # pjit (see openpi.training.sharding).
        x = sharding.activation_sharding_constraint(x)
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["sa"] = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            deterministic=deterministic,
            dtype=self.dtype_mm,
        )(y, y)
        y = sharding.activation_sharding_constraint(y)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+sa"] = x + y  # residual connection around self-attention
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["mlp"] = MlpBlock(
            mlp_dim=self.mlp_dim,
            dropout=self.dropout,
            dtype_mm=self.dtype_mm,
        )(y, deterministic)
        y = sharding.activation_sharding_constraint(y)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+mlp"] = x + y  # residual connection around the MLP
        x = sharding.activation_sharding_constraint(x)
        return x, out
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""
    depth: int
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    # When True, stack the layers with nn.scan (one set of scanned params)
    # instead of `depth` individually-named blocks. NOTE: the two modes
    # produce different parameter tree layouts, so checkpoints are not
    # interchangeable between them.
    scan: bool = False
    remat_policy: str = "nothing_saveable"  # jax.checkpoint rematerialization policy
    dtype_mm: str = "float32"
    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        """Runs `depth` encoder blocks followed by a final LayerNorm.

        Returns (normed_x, out) where out contains per-block activations.
        """
        out = {}
        if self.scan:
            # Wrap the block in remat so activations are recomputed in the
            # backward pass according to `remat_policy`.
            block = nn.remat(
                Encoder1DBlock,
                prevent_cse=False,
                static_argnums=(2,),  # 0=self, 2=deterministic
                policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
            )
            x, scan_out = nn.scan(
                block,
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=nn.broadcast,
                length=self.depth,
            )(
                name="encoderblock",
                dtype_mm=self.dtype_mm,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout=self.dropout,
            )(x, deterministic)
            # Unstack the scanned per-layer outputs into per-block entries.
            for lyr in range(self.depth):
                out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out)
        else:
            # Input Encoder
            for lyr in range(self.depth):
                block_cur = Encoder1DBlock(
                    name=f"encoderblock_{lyr}",
                    dtype_mm=self.dtype_mm,
                    mlp_dim=self.mlp_dim,
                    num_heads=self.num_heads,
                    dropout=self.dropout,
                )
                x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
        out["pre_ln"] = x  # Alias for last block, but without the number in it.
        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out
class MAPHead(nn.Module):
    """Multihead Attention Pooling.

    Pools a [n, l, d] sequence into a single [n, d] vector by attending a
    learned probe token against the sequence, followed by a residual MLP.
    """
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dtype_mm: str = "float32"  # compute dtype for attention/MLP layers
    @nn.compact
    def __call__(self, x):
        n, _, d = x.shape  # n,l,d
        probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)
        probe = jnp.tile(probe, [n, 1, 1])
        # Cross-attention: the single probe token queries the full sequence.
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dtype=self.dtype_mm,
            kernel_init=nn.initializers.xavier_uniform(),
        )(probe, x)
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        # Bug fix: MlpBlock's field is `dtype_mm` (it has no `dtype` field), so
        # passing `dtype=` raised a TypeError from the linen dataclass __init__.
        x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype_mm=self.dtype_mm)(y)
        return x[:, 0]
class _Module(nn.Module):
    """ViT model.

    Returns `(x, out)` where `x` is the head output (or the pooled features
    when `num_classes` is None) and `out` is a dict of named intermediates.
    """

    num_classes: int | None = None
    patch_size: Sequence[int] = (16, 16)
    width: int = 768
    depth: int = 12
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    posemb: str = "learn"  # Can also be "sincos2d"
    rep_size: int | bool = False
    dropout: float = 0.0
    pool_type: str = "gap"  # Can also be "map" or "tok"
    head_zeroinit: bool = True
    scan: bool = False
    # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
    remat_policy: str = "nothing_saveable"
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, image, *, train=False):
        out = {}
        # Kevin edit: do patch extraction and posemb in float32,
        # because I feel like it's a bit safer.
        image = jnp.asarray(image, jnp.float32)
        # Patch extraction
        x = out["stem"] = nn.Conv(
            self.width,
            self.patch_size,
            strides=self.patch_size,
            padding="VALID",
            name="embedding",
            dtype=jnp.float32,
        )(image)
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])
        # Add posemb before adding extra token.
        x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32)
        if self.pool_type == "tok":
            # Prepend a learned class token; it is stripped from `encoded` below.
            cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
            x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)
        n, _, c = x.shape  # n,l,d
        x = nn.Dropout(rate=self.dropout)(x, not train)
        # Kevin edit: now cast back to dtype_mm (potentially half precision)
        x = x.astype(self.dtype_mm)
        x, out["encoder"] = Encoder(
            depth=self.depth,
            mlp_dim=self.mlp_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
            scan=self.scan,
            remat_policy=self.remat_policy,
            dtype_mm=self.dtype_mm,
            name="Transformer",
        )(x, deterministic=not train)
        encoded = out["encoded"] = x
        if self.pool_type == "map":
            # Bug fix: MAPHead's field is `dtype_mm` (it has no `dtype` field),
            # so passing `dtype=` raised a TypeError from the dataclass __init__.
            x = out["head_input"] = MAPHead(
                num_heads=self.num_heads,
                mlp_dim=self.mlp_dim,
                dtype_mm=self.dtype_mm,
            )(x)
        elif self.pool_type == "gap":
            x = out["head_input"] = jnp.mean(x, axis=1)
        elif self.pool_type == "0":
            x = out["head_input"] = x[:, 0]
        elif self.pool_type == "tok":
            x = out["head_input"] = x[:, 0]
            encoded = encoded[:, 1:]
        elif self.pool_type == "none":
            pass
        else:
            raise ValueError(f"Unknown pool type: '{self.pool_type}'")
        x_2d = jnp.reshape(encoded, [n, h, w, -1])
        if self.rep_size:
            rep_size = self.width if self.rep_size is True else self.rep_size
            hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits")
            # NOTE: In the past we did not include tanh in pre_logits.
            # For few-shot, it should not matter much, as it whitens anyways.
            x_2d = nn.tanh(hid(x_2d))
            x = nn.tanh(hid(x))
        out["pre_logits_2d"] = x_2d
        out["pre_logits"] = x
        if self.num_classes:
            # Zero-init head keeps the initial logits at zero (common ViT trick).
            kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
            head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw)
            x_2d = out["logits_2d"] = head(x_2d)
            x = out["logits"] = head(x)
        return x, out
def Module(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name # noqa: N802
    """Factory for _Module: variant presets first, explicit kwargs override."""
    params = dict(decode_variant(variant))
    params.update(kw)
    return _Module(num_classes, **params)
def decode_variant(variant):
    """Converts a string like "B" or "B/32" into a params dict."""
    if variant is None:
        return {}
    v, patch = variant, {}
    if "/" in variant:
        v, patch_str = variant.split("/")
        patch = {"patch_size": (int(patch_str), int(patch_str))}
    # (width, depth, mlp_dim, num_heads) per variant.
    # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
    sizes = {
        "mu": (32, 1, 128, 2),
        "Ti": (192, 12, 768, 3),
        "S": (384, 12, 1536, 6),
        "M": (512, 12, 2048, 8),
        "B": (768, 12, 3072, 12),
        "L": (1024, 24, 4096, 16),
        "So400m": (1152, 27, 4304, 16),
        "H": (1280, 32, 5120, 16),
        "g": (1408, 40, 6144, 16),
        "g-opt": (1536, 40, 6144, 16),
        "G": (1664, 48, 8192, 16),
        "G-opt": (1536, 48, 8192, 16),
        "e": (1792, 56, 15360, 16),
    }
    width, depth, mlp_dim, num_heads = sizes[v]
    return {
        "width": width,
        "depth": depth,
        "mlp_dim": mlp_dim,
        "num_heads": num_heads,
        **patch,
    }

View File

@@ -0,0 +1,371 @@
import logging
import os
import jax
import numpy as np
import orbax.checkpoint as ocp
import sentencepiece
from transformers import AutoProcessor
import openpi.models.utils.fsq_tokenizer as fsq_tokenizer
import openpi.shared.download as download
class PaligemmaTokenizer:
    """SentencePiece tokenizer for PaliGemma prompts (Pi0 / Pi05 formats)."""

    def __init__(self, max_len: int = 48):
        # max_len: fixed output length; shorter prompts are padded, longer
        # prompts are truncated (with a warning).
        self._max_len = max_len
        # Downloads the PaliGemma SentencePiece model on first use.
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
    def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
        """Returns (tokens, mask), both of length `max_len`.

        `mask` is True on real tokens and False on padding.
        """
        cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ")
        if state is not None:
            # This is the Pi05 format, where the state is part of the discrete language input.
            discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
            state_str = " ".join(map(str, discretized_state))
            full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
            tokens = self._tokenizer.encode(full_prompt, add_bos=True)
        else:
            # This is the Pi0 format, where the state is part of the continuous action expert input.
            # tokenize "\n" separately as the "start of answer" token
            tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            # Pads with False (== 0) — presumably 0 is the pad token id; the
            # mask marks these positions invalid regardless. TODO confirm.
            padding = [False] * (self._max_len - tokens_len)
            mask = [True] * tokens_len + padding
            tokens = tokens + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            # NOTE(review): truncation drops the trailing tokens, which in the
            # Pi0 branch includes the "\n" start-of-answer token — confirm this
            # is acceptable for downstream consumers.
            tokens = tokens[: self._max_len]
            mask = [True] * self._max_len
        return np.asarray(tokens), np.asarray(mask)
class FASTTokenizer:
    """Tokenizer for Pi0-FAST: PaliGemma text tokens + FAST action tokens.

    Action tokens produced by the FAST tokenizer are remapped into the last
    (otherwise unused) portion of the PaliGemma vocabulary.
    """

    def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"):
        self._max_len = max_len
        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
        # Instantiate FAST tokenizer
        self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Returns (tokens, token_mask, ar_mask, loss_mask), each of length `max_len`.

        ar_mask: 0 on the prefix (bidirectional attention), 1 on the postfix
        (causal). loss_mask: True only on the postfix (action) tokens.
        """
        cleaned_text = prompt.lower().strip().replace("_", " ")
        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
        if actions is not None:
            # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
            action_tokens = self._fast_tokenizer(actions[None])[0]
            action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens)
            # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
            postfix_tokens = (
                self._paligemma_tokenizer.encode("Action: ")
                + action_tokens_in_pg.tolist()
                + self._paligemma_tokenizer.encode("|", add_eos=True)
            )
        else:
            postfix_tokens = []
        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only
        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            # The shared False padding evaluates to 0 for the int ar_mask and
            # to the pad id for tokens; the masks mark these slots invalid.
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            # NOTE(review): truncation can cut off action tokens and the EOS,
            # which makes the example unrecoverable at decode time — confirm
            # this is rare enough in practice.
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]
        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        """Inverse of `tokenize`'s postfix: decodes model output tokens to actions.

        Returns zeros when no "Action: " marker is present in the decoded text.
        """
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)
        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        # The remapping is an involution, so the same helper maps both ways.
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        return self._fast_tokenizer.decode(
            [action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim
        )[0]
    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        # Mirrors action-token ids into the tail of the PaliGemma vocab,
        # skipping the last `_fast_skip_tokens` special tokens.
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
###########################################################################
## The tokenizers below are used for RoboArena baseline implementations. ##
## They are *not* used for pi0-style models. ##
###########################################################################
class BinningTokenizer:
    """
    Standard RT-2 / OpenVLA style binning tokenizer.

    Inference-only RoboArena baseline: `tokenize` builds the prefix, and
    `extract_actions` maps decoded bin tokens back to [-1, 1] actions.
    """
    def __init__(self, max_len: int = 256, n_bins: int = 256):
        self._max_len = max_len
        self._n_bins = n_bins  # number of discrete action bins
        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Tokenize a prompt and state into a sequence of tokens.
        Args:
            prompt: The text prompt to tokenize.
            state: The state array to discretize and tokenize.
            actions: Must be None. Action encoding is not currently supported.
        Returns:
            A tuple of (tokens, token_mask, ar_mask, targets).
        Raises:
            NotImplementedError: If actions is not None.
        """
        cleaned_text = prompt.lower().strip().replace("_", " ")
        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
        if actions is not None:
            raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []
        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only
        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            # False (== 0) padding; masks mark these slots invalid.
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]
        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        """Decodes bin tokens after "Action: " into an [action_horizon, action_dim] array.

        Returns zeros when the marker is missing or too few tokens were produced.
        """
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)
        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        if len(action_tokens) < action_horizon * action_dim:
            return np.zeros([action_horizon, action_dim], dtype=np.float32)
        action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
        # NOTE(review): maps bin index k to the lower bin edge k/n_bins*2-1
        # rather than the bin center — confirm this matches the encoder's
        # np.digitize convention.
        return action_tokens / self._n_bins * 2 - 1
    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        # Mirrors token ids into the tail of the PaliGemma vocab (involution),
        # skipping the last `_fast_skip_tokens` special tokens.
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
class FSQTokenizer:
"""
FSQ tokenizer from the FAST paper baselines.
"""
    def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
        """Loads the FSQ tokenizer checkpoint and the PaliGemma text tokenizer.

        Args:
            max_len: fixed length of tokenized output sequences.
            fsq_tokenizer_path: required path/URL of the orbax checkpoint
                directory holding the FSQ tokenizer params and config.
        """
        self._max_len = max_len
        assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
        # Download tokenizer
        path = download.maybe_download(fsq_tokenizer_path)
        # Assumes the download directory contains exactly one entry: the step
        # directory of the checkpoint. TODO confirm layout.
        tok_path = os.path.join(path, os.listdir(path)[0])
        # Split step from path
        step = int(tok_path.split("/")[-1])
        base_path = tok_path.rsplit("/", 1)[0]
        mgr = ocp.CheckpointManager(
            base_path,
            item_handlers={
                "params": ocp.StandardCheckpointHandler(),
                "opt_state": ocp.StandardCheckpointHandler(),
                "config": ocp.JsonCheckpointHandler(),
            },
            options=ocp.CheckpointManagerOptions(max_to_keep=1),
        )
        try:
            restored = mgr.restore(
                step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
            )
            config = restored["config"]
            self._params = restored["params"]
            self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
        except Exception as e:
            # Wrap any restore failure with the offending path for debuggability.
            raise RuntimeError(
                f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
            ) from e
        # Compile tokenize and detokenize functions
        self._tokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
        )
        self._detokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
        )
        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
postfix_tokens = []
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens = prefix_tokens + postfix_tokens
token_mask = [True] * len(tokens)
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
# Pad tokens to max length
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [False] * (self._max_len - tokens_len)
tokens = tokens + padding
token_mask = token_mask + padding
ar_mask = ar_mask + padding
loss_mask = loss_mask + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
token_mask = token_mask[: self._max_len]
ar_mask = ar_mask[: self._max_len]
loss_mask = loss_mask[: self._max_len]
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
try:
# Move computation to CPU and compile on-demand
device = jax.devices("cpu")[0]
with jax.default_device(device):
detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
except Exception as e:
logging.warning(f"Error decoding FSQ: {e}")
return np.zeros((action_horizon, action_dim))
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens

View File

@@ -0,0 +1,27 @@
import numpy as np
from openpi.models import tokenizer as _tokenizer
def test_tokenize():
    """PaligemmaTokenizer produces fixed-length token and mask arrays."""
    tok = _tokenizer.PaligemmaTokenizer(max_len=10)
    tokens, masks = tok.tokenize("Hello, world!")
    expected = (10,)
    assert tokens.shape == expected
    assert masks.shape == expected
def test_fast_tokenizer():
    """FASTTokenizer turns prompt/state/actions into fixed-size arrays and back."""
    state = np.random.rand(5).astype(np.float32)
    action = np.random.rand(3, 2).astype(np.float32)
    tok = _tokenizer.FASTTokenizer(max_len=256)
    outputs = tok.tokenize("Hello, world!", state, action)
    # tokens, token_masks, ar_masks, loss_masks — all padded to max_len.
    for arr in outputs:
        assert arr.shape == (256,)
    act = tok.extract_actions(outputs[0], 3, 2)
    assert act.shape == (3, 2)

View File

@@ -0,0 +1,472 @@
import math
from typing import Any, Literal
import chex
from einops import einops
from flax import linen as nn
from flax.linen.module import Module
from flax.linen.module import compact
from flax.struct import dataclass
from flax.typing import Array
import jax
import jax.numpy as jnp
class FsqCodebook(nn.Module):
    """Finite Scalar Quantization (FSQ) codebook.

    Projects inputs down to ``len(bins_per_dim)`` dimensions, bounds each
    dimension with tanh, rounds it into a fixed number of bins, and packs the
    per-dimension digits into a single integer token (mixed-radix encoding).

    Attributes:
        input_dim: width of the un-projected latent (output width of decode).
        target_codebook_size: desired total number of discrete tokens.
        codebook_type: strategy used to derive the bins per dimension.
        _bins_per_dim: optional explicit override of the bins per dimension.
    """

    input_dim: int
    target_codebook_size: int
    # "custom" was previously missing from this Literal even though
    # bins_per_dim dispatches on it and FsqAttentionTokenizer passes it.
    codebook_type: Literal["fsq", "lfq", "custom"]
    _bins_per_dim: tuple[int, ...] | None = None

    @property
    def bins_per_dim(self) -> tuple[int, ...]:
        """Number of quantization bins for each latent dimension."""
        if self._bins_per_dim is not None:
            return self._bins_per_dim
        if self.codebook_type == "fsq":
            return self._get_bins_fsq(self.target_codebook_size)
        if self.codebook_type == "lfq":
            return self._get_bins_lfq(self.target_codebook_size)
        if self.codebook_type == "custom":
            return self._get_bins_custom(self.target_codebook_size)
        raise ValueError(f"Codebook type {self.codebook_type} not supported.")

    @property
    def place_values(self) -> jnp.ndarray:
        """Mixed-radix place values: cumulative products of the bin counts."""
        place_values = [1]
        for b in self.bins_per_dim[:-1]:
            place_values.append(place_values[-1] * b)
        return jnp.array(place_values)

    @staticmethod
    def _get_bins_fsq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension based on codebook size, from the original FSQ paper.
        """
        bins = {
            2**8: (8, 6, 5),
            2**10: (8, 5, 5, 5),
            2**12: (7, 5, 5, 5, 5),
            2**14: (8, 8, 8, 6, 5),
            2**16: (8, 8, 8, 5, 5, 5),
        }.get(target_codebook_size)
        if bins is None:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")
        return bins

    @staticmethod
    def _get_bins_custom(target_codebook_size: int) -> tuple[int, ...]:
        """Two-dimensional codebooks with equal bins per dimension."""
        bins = {
            2**8: (16, 16),
            2**10: (32, 32),
            2**12: (64, 64),
            2**14: (128, 128),
            2**16: (256, 256),
        }.get(target_codebook_size)
        if bins is None:
            # Previously this silently returned None, which only surfaced later
            # as an opaque TypeError; fail fast like the other modes instead.
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")
        return bins

    @staticmethod
    def _get_bins_lfq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
        """
        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"
        return (2,) * int(math.log2(target_codebook_size))

    def setup(self):
        self.proj_down = nn.Dense(len(self.bins_per_dim))
        self.proj_up = nn.Dense(self.input_dim)

    def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Encode then decode; returns (tokens, reconstruction with STE gradient)."""
        tokens, z = self.encode(inputs)
        output = self.decode(tokens, z_grad=z)
        return tokens, output

    def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Quantize inputs; returns (integer tokens, pre-quantization latents)."""
        bases = jnp.array(self.bins_per_dim)
        x = self.proj_down(inputs)
        z = jnp.tanh(x)  # bound each dimension to [-1, 1] before binning
        # Quantize: map [-1, 1] onto {0, ..., bins-1} per dimension.
        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
        tokens = self.undigitize(digits)
        return tokens, z

    def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:
        """Reconstruct latents from tokens, optionally with a straight-through gradient."""
        bases = jnp.array(self.bins_per_dim)
        digits = self.digitize(tokens)
        z_q = digits / (bases - 1) * 2 - 1
        if z_grad is not None:
            # Straight-through estimator: forward pass uses z_q, gradient flows to z_grad.
            chex.assert_equal_shape([z_q, z_grad])
            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
        return self.proj_up(z_q)

    def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
        """Pack per-dimension digits into a single mixed-radix integer token."""
        return jnp.sum(digits * self.place_values, axis=-1)

    def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
        """Unpack integer tokens into per-dimension digits."""
        return (tokens[..., None] // self.place_values) % jnp.array(self.bins_per_dim)

    @property
    def vocab_size(self) -> int:
        """Total number of distinct tokens representable by this codebook."""
        return math.prod(self.bins_per_dim)
class ResNetDownBlock(nn.Module):
    """1-D residual block with optional strided downsampling on the main path."""

    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        residual = x
        if self.stride > 1 or x.shape[-1] != self.n_filters:
            # Project the skip path so its length/width matches the main path.
            residual = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(residual)
        h = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
        h = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(h)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=not train)
        h = nn.relu(h)
        h = nn.Conv(self.n_filters, (3,), (1,), "SAME")(h)
        return residual + h
class ResNetUpBlock(nn.Module):
    """1-D residual block that upsamples via transposed convolutions."""

    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        residual = x
        if self.stride > 1:
            # Upsample the skip path so its length matches the main path.
            residual = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(residual)
        h = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
        h = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(h)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=not train)
        h = nn.relu(h)
        h = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(h)
        return residual + h
@dataclass
class LfqCodebookOutput:
    """Outputs of LookupFreeQuantization.loss (a flax.struct dataclass)."""
    tokens: jnp.ndarray  # packed integer token ids
    z: jnp.ndarray  # continuous (pre-quantization) latents, projected back up
    z_q: jnp.ndarray  # straight-through quantized latents, projected back up
    token_log_probs: jnp.ndarray  # per-token log-probs (loss() currently returns zeros here)
    commit_loss: jnp.ndarray  # mean squared commitment loss between z and z_q
class LookupFreeQuantization(nn.Module):
    """Lookup-Free Quantization (LFQ): each latent dimension is quantized to ±1.

    Attributes:
        num_dims: number of binary latent dimensions (vocab size is 2**num_dims).
        latent_dim: width of the un-projected latent space.
    """

    num_dims: int
    latent_dim: int

    def setup(self):
        self.codebook = jnp.array([-1, 1])
        self.activation = nn.tanh
        self.project_down = nn.Dense(self.num_dims)
        self.project_up = nn.Dense(self.latent_dim)

    def encode(self, z: jnp.ndarray) -> jnp.ndarray:
        """Quantize latents into integer tokens (one bit per dimension).

        NOTE(review): unlike loss(), no tanh activation is applied before
        quantization here — confirm the asymmetry is intended.
        """
        z = self.project_down(z)
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        token_bits = jnp.argmin(token_squared_distances, axis=-1)  # 0 or 1 per dimension
        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)

    def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
        """Map integer tokens back to latents through the ±1 codebook."""
        # Extract each bit as 0/1. The previous form masked with 2**k, producing
        # values in {0, 2**k} that only indexed the length-2 codebook correctly
        # because JAX clamps out-of-bounds gather indices; shifting first makes
        # the bit values explicit without changing the result.
        token_bits = ((tokens[..., None] >> jnp.arange(self.num_dims)) & 1).astype(jnp.int32)
        return self.project_up(self.codebook[token_bits])

    def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
        """Straight-through quantization pass returning latents and commit loss."""
        z = self.project_down(x)
        z = self.activation(z)
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        tokens = jnp.argmin(token_squared_distances, axis=-1)
        token_bit_log_probs = -token_squared_distances
        # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
        token_bit_expansions = jnp.bitwise_and(
            jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
        ).astype(jnp.int32)
        token_log_probs = (
            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
            + token_bit_log_probs[..., 1] @ token_bit_expansions
        )  # (batch_size, num_tokens, 2 ** num_dims)
        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
        z_q = self.codebook[tokens]
        commit_loss = jnp.square(z - z_q).mean()
        # Straight-through estimator: forward pass uses z_q, gradients flow to z.
        z_q = jax.lax.stop_gradient(z_q - z) + z
        z_q = self.project_up(z_q)
        z = self.project_up(z)
        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
        return LfqCodebookOutput(
            tokens=tokens,
            z=z,
            z_q=z_q,
            # NOTE(review): the token_log_probs computed above are discarded and
            # zeros are returned instead — confirm whether this is intentional.
            token_log_probs=jnp.zeros(()),
            commit_loss=commit_loss,
        )
def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:
    """Block-causal attention mask: query i may attend key j iff i // bs_k >= j // bs_q.

    NOTE(review): the query index is divided by ``bs_k`` and the key index by
    ``bs_q`` — verify against the caller (CrossAttentionLayer) that this
    pairing of block sizes is intended and not swapped.
    """
    return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
class GeGLU(Module):
    """Gated Linear Unit with GELU (GeGLU) activation function.

    Projects the input to twice the output width, splits the projection into a
    value half and a gate half, and returns ``value * gelu(gate)``.

    Attributes:
        output_dim: number of output features; -1 keeps the input width.
    """

    output_dim: int = -1

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Apply the GeGLU transformation to ``inputs``."""
        n_out = self.output_dim if self.output_dim != -1 else inputs.shape[-1]
        projected = nn.Dense(n_out * 2)(inputs)
        value, gate = projected[..., :n_out], projected[..., n_out:]
        return value * nn.gelu(gate)
class CrossAttentionLayer(nn.Module):
    """Pre-LayerNorm transformer layer: self-attention, cross-attention, then GeGLU MLP.

    Attributes:
        dropout_rate: dropout used in attention and the MLP.
        num_heads: number of attention heads; defaults to d_embed // 64 when None.
        causal: if True, builds a causal self mask and a block-causal cross mask
            internally, overriding any masks passed by the caller.
        mlp_ratio: hidden width of the MLP relative to the embedding width.
    """
    dropout_rate: float = 0.0
    num_heads: int | None = None
    causal: bool = False
    mlp_ratio: float = 4.0
    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        y: jnp.ndarray,
        *,
        mask_self: jnp.ndarray | None = None,
        mask_cross: jnp.ndarray | None = None,
        train: bool = True,
    ) -> jnp.ndarray:
        """Attend x to itself, then to y; returns an array shaped like x."""
        d_embed = x.shape[-1]
        seq_len_q = x.shape[-2]
        seq_len_k = y.shape[-2]
        if self.causal:
            # One block size will be 1
            bs_q = max(seq_len_q // seq_len_k, 1)
            bs_k = max(seq_len_k // seq_len_q, 1)
            mask_self = nn.make_causal_mask(x[..., 0])
            mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)
        # Self-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, x, x, mask=mask_self)
        x = skip + x
        # Cross-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, y, y, mask=mask_cross)
        x = skip + x
        # MLP block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = GeGLU()(x)
        x = nn.Dense(d_embed)(x)
        return skip + x
def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
    """Flax initializer producing fixed sinusoidal positional encodings.

    The first half of the embedding dimensions holds sines and the second half
    cosines at geometrically spaced frequencies. The ignored first argument is
    the PRNG key expected by the flax initializer signature.
    """
    seq_len, d_embed = shape
    positions = jnp.arange(seq_len)
    frequencies = jnp.exp(-(jnp.log(10000.0) / d_embed) * jnp.arange(0, d_embed, 2))
    angles = positions[:, jnp.newaxis] * frequencies
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
class TokenizerEncoderDecoder(nn.Module):
    """Cross-attention stack mapping a sequence onto learned query tokens.

    Learned queries of length `num_tokens` (initialized with sinusoidal PEs)
    attend to the input sequence `y` of length `num_cross_tokens`, optionally
    with a projected state vector appended as one extra key/value position.

    Attributes:
        num_tokens: number of output (query) positions.
        num_cross_tokens: expected sequence length of the input `y`.
        num_layers: number of stacked CrossAttentionLayer blocks.
        causal: forwarded to CrossAttentionLayer (block-causal masking).
        mlp_ratio: MLP width multiplier inside each layer.
        use_state_conditioning: if True, `state_conditioning` is required and
            appended to the cross-attention inputs.
    """
    num_tokens: int
    num_cross_tokens: int
    num_layers: int
    causal: bool
    mlp_ratio: float = 4.0
    use_state_conditioning: bool = False
    @nn.compact
    def __call__(
        self,
        y: jnp.ndarray,
        *,
        train: bool = True,
        state_conditioning: jnp.ndarray | None = None,
        mask: jnp.ndarray | None = None,
    ) -> jnp.ndarray:
        """Cross-attend the learned queries to `y`; returns (..., num_tokens, d)."""
        # Learned query embeddings, broadcast over the batch dimensions of y.
        x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
        x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
        if mask is not None:
            # mask is (batch_dims..., num_cross_tokens)
            chex.assert_equal_shape([y[..., 0], mask])
            attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
        else:
            attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens))
        if self.use_state_conditioning:
            assert state_conditioning is not None, "State conditioning is required for this model."
            # Append the projected state as one always-attendable key/value position.
            state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
            y = jnp.concatenate([y, state_embed], axis=-2)
            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)
        y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])
        for _ in range(self.num_layers):
            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
                x, y, train=train, mask_self=None, mask_cross=attn_mask
            )
        return x
class FsqAttentionTokenizer(nn.Module):
    """Attention-based FSQ tokenizer for continuous action chunks.

    An encoder cross-attends the action sequence onto `num_tokens` latents,
    which an FsqCodebook quantizes into discrete tokens; a decoder maps tokens
    back to a continuous action chunk of shape (..., data_horizon, data_dim).
    """
    embed_dim: int
    data_dim: int
    data_horizon: int
    num_tokens: int
    num_layers: int
    target_codebook_size: int
    causal: bool = False
    mlp_ratio: float = 2.0
    bound: float | None = None
    use_state_conditioning: bool = False
    @property
    def vocab_size(self) -> int:
        # NOTE(review): this uses the "fsq" bin layout, but setup() constructs
        # the codebook with codebook_type="custom" — for e.g. 2**8 these give
        # 240 vs 256 tokens. Verify which vocabulary size is intended.
        return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))  # noqa: SLF001
    def setup(self):
        # Encoder: data_horizon -> num_tokens latents; decoder is the inverse mapping.
        self.proj = nn.Dense(self.embed_dim)
        self.encoder = TokenizerEncoderDecoder(
            num_tokens=self.num_tokens,
            num_cross_tokens=self.data_horizon,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.codebook = FsqCodebook(
            input_dim=self.embed_dim,
            target_codebook_size=self.target_codebook_size,
            codebook_type="custom",
        )
        self.decoder = TokenizerEncoderDecoder(
            num_tokens=self.data_horizon,
            num_cross_tokens=self.num_tokens,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.proj_mean = nn.Dense(self.data_dim)
        # Learned global scale applied to the decoded mean.
        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))
    def tokenize(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Encode actions into (tokens, pre-quantization latents)."""
        if self.bound is not None:
            # NOTE(review): clipping is applied here but not in loss() — confirm intended.
            action = jnp.clip(action, -self.bound, self.bound)
        x = self.proj(action)
        x = self.encoder(x, train=train, state_conditioning=obs)
        return self.codebook.encode(x)
    def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:
        """Decode discrete tokens back into a continuous action chunk."""
        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
        mean = self.proj_mean(x)
        return mean * self.out_scale
    def loss(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """Reconstruction objective; returns (mse, metrics dict)."""
        # Encode
        x = self.proj(action)
        z = self.encoder(x, train=train, state_conditioning=obs)
        # Quantize
        tokens, z = self.codebook(z)
        # Decode
        x = self.decoder(z, train=train, state_conditioning=obs)
        mean = self.proj_mean(x) * self.out_scale
        mse = jnp.mean(jnp.square(action - mean))
        mae = jnp.mean(jnp.abs(action - mean))
        return mse, {
            "mse": mse,
            "mae": mae,
        }
    def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """
        Dummy for .init
        """
        return self.loss(*args, **kwargs)

View File

@@ -0,0 +1,307 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""
from collections.abc import Callable
from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp
from openpi.models import resnet as models_resnet
Array = Any
PRNGKey = Any
Shape = tuple[int]
Dtype = Any
class IdentityLayer(nn.Module):
    """Identity layer, convenient for giving a name to an array."""
    @nn.compact
    def __call__(self, x):
        # Pure pass-through; exists only so the array gets a named module scope.
        return x
class AddPositionEmbs(nn.Module):
    """Adds learned positional embeddings to the inputs.

    Attributes:
        posemb_init: positional embedding initializer.
    """

    posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Add a learned per-position embedding to `inputs` of shape (bs, timesteps, in_dim)."""
        assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
        # One embedding per sequence position, broadcast over the batch dimension.
        embedding_shape = (1, *inputs.shape[1:])
        pos_embedding = self.param("pos_embedding", self.posemb_init, embedding_shape, self.param_dtype)
        return inputs + pos_embedding
class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout."""

    mlp_dim: int
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    out_dim: int | None = None
    dropout_rate: float = 0.1
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """Applies Transformer MlpBlock module."""
        # The output width defaults to the input width when out_dim is unset.
        out_features = inputs.shape[-1] if self.out_dim is None else self.out_dim
        dense_kwargs = dict(
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        hidden = nn.Dense(features=self.mlp_dim, **dense_kwargs)(inputs)  # pytype: disable=wrong-arg-types
        hidden = nn.gelu(hidden)
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=deterministic)
        out = nn.Dense(features=out_features, **dense_kwargs)(hidden)  # pytype: disable=wrong-arg-types
        return nn.Dropout(rate=self.dropout_rate)(out, deterministic=deterministic)
class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
        inputs: input data.
        mlp_dim: dimension of the mlp on top of attention block.
        dtype: the dtype of the computation (default: float32).
        dropout_rate: dropout rate.
        attention_dropout_rate: dropout for attention heads.
        deterministic: bool, deterministic or not (to apply dropout).
        num_heads: Number of heads in nn.MultiHeadDotProductAttention
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, deterministic):
        """Applies Encoder1DBlock module.

        Args:
            inputs: Inputs to the layer.
            deterministic: Dropout will not be applied when set to true.

        Returns:
            output after transformer encoder block.
        """
        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
            # Compute the softmax in float32 for numerical stability under low-precision dtypes.
            force_fp32_for_softmax=True,
        )(x, x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs
        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
            y, deterministic=deterministic
        )
        # The trailing None fills the per-layer scan-output slot expected by
        # nn.scan in Encoder (which discards it).
        return x + y, None
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Attributes:
        num_layers: number of layers
        mlp_dim: dimension of the mlp on top of attention block
        num_heads: Number of heads in nn.MultiHeadDotProductAttention
        dropout_rate: dropout rate.
        attention_dropout_rate: dropout rate in self attention.
    """

    dtype: jax.typing.DTypeLike
    num_layers: int
    mlp_dim: int
    num_heads: int
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    add_position_embedding: bool = True

    @nn.compact
    def __call__(self, x, *, train):
        """Applies Transformer model on the inputs.

        Args:
            x: Inputs to the layer.
            train: Set to `True` when training.

        Returns:
            output of a transformer encoder.
        """
        assert x.ndim == 3  # (batch, len, emb)
        if self.add_position_embedding:
            x = AddPositionEmbs(
                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
                name="posembed_input",
            )(x)
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        x = x.astype(self.dtype)
        # Input Encoder
        # remat + scan over layers: recompute activations in the backward pass
        # and stack all layer parameters along a leading axis ("encoderblock").
        block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
        x, _ = nn.scan(
            block,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=nn.broadcast,
            length=self.num_layers,
        )(
            name="encoderblock",
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            num_heads=self.num_heads,
        )(x, not train)
        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
class VisionTransformer(nn.Module):
    """VisionTransformer.

    Attributes:
        dtype: computation dtype of the transformer.
        num_classes: size of the classification head (falsy disables the head).
        patches: config object with a `size` field giving the patch size.
        transformer: kwargs dict for the `Encoder`; None skips the transformer.
        hidden_size: patch embedding width.
        resnet: optional config for a (possibly partial) ResNet stem.
        representation_size: optional width of an extra tanh "pre_logits" layer.
        classifier: one of "token", "gap", "unpooled", "token_unpooled".
        head_bias_init: initial bias value of the classification head.
    """

    dtype: jax.typing.DTypeLike
    num_classes: int
    patches: Any
    transformer: Any
    hidden_size: int
    resnet: Any | None = None
    representation_size: int | None = None
    classifier: str = "token"
    head_bias_init: float = 0.0
    encoder: type[nn.Module] = Encoder
    model_name: str | None = None

    @nn.compact
    def __call__(self, inputs, *, train):
        x = inputs
        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)
            # Root block.
            x = models_resnet.StdConv(
                features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
            )(x)
            x = nn.GroupNorm(name="gn_root")(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
            # ResNet stages.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
                )(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(
                        block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
                    )(x)
        # NOTE(review): this unpack is unused (shadowed by the one below); kept as-is.
        n, h, w, c = x.shape
        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(
            features=self.hidden_size,
            kernel_size=self.patches.size,
            strides=self.patches.size,
            padding="VALID",
            name="embedding",
        )(x)
        # Here, x is a grid of embeddings.
        # (Possibly partial) Transformer.
        if self.transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])
            # If we want to add a class token, add it here.
            if self.classifier in ["token", "token_unpooled"]:
                cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)
            x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)
        if self.classifier == "token":
            x = x[:, 0]
        elif self.classifier == "gap":
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        elif self.classifier in ["unpooled", "token_unpooled"]:
            pass
        else:
            raise ValueError(f"Invalid classifier={self.classifier}")
        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name="pre_logits")(x)
        if self.num_classes:
            # Head kernel is initialized to zeros; bias to head_bias_init.
            x = nn.Dense(
                features=self.num_classes,
                name="head",
                kernel_init=nn.initializers.zeros,
                bias_init=nn.initializers.constant(self.head_bias_init),
            )(x)
        return x

View File

@@ -0,0 +1,281 @@
from typing import Literal
import pytest
import torch
from torch import nn
from transformers import GemmaForCausalLM
from transformers import PaliGemmaForConditionalGeneration
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
class PaliGemmaWithExpertModel(nn.Module):
    def __init__(
        self,
        vlm_config,
        action_expert_config,
        use_adarms: list[bool] | None = None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
    ):
        """Build a HF PaliGemma VLM plus a smaller Gemma "action expert".

        Args:
            vlm_config: config object (width/depth/mlp_dim/heads/...) for the
                PaliGemma language tower.
            action_expert_config: same kind of config for the action expert.
            use_adarms: two flags [vlm, expert] enabling adaptive-RMSNorm
                conditioning on the respective model; defaults to [False, False].
            precision: overall parameter dtype; selected params stay float32.
        """
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()
        # Translate the JAX-side gemma config into a HF PaliGemma config.
        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
        vlm_config_hf.image_token_index = 257152
        vlm_config_hf.text_config.hidden_size = vlm_config.width
        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
        vlm_config_hf.text_config.head_dim = vlm_config.head_dim
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.torch_dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.torch_dtype = "float32"
        # The action expert is a plain Gemma LM with its own (smaller) geometry.
        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,
            hidden_size=action_expert_config.width,
            intermediate_size=action_expert_config.mlp_dim,
            num_attention_heads=action_expert_config.num_heads,
            num_hidden_layers=action_expert_config.depth,
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )
        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
        # The expert receives embeddings directly; it never embeds token ids.
        self.gemma_expert.model.embed_tokens = None
        self.to_bfloat16_for_selected_params(precision)
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
    def embed_image(self, image: torch.Tensor):
        """Embed images via the PaliGemma vision stack (`get_image_features`)."""
        return self.paligemma.model.get_image_features(image)
    def embed_language_tokens(self, tokens: torch.Tensor):
        """Look up token embeddings from the PaliGemma language model."""
        return self.paligemma.language_model.embed_tokens(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Force enable gradient checkpointing if we're in training mode and the model supports it
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
if not self.gemma_expert.model.gradient_checkpointing:
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
self.gemma_expert.model.gradient_checkpointing = True
use_gradient_checkpointing = True
# Debug gradient checkpointing status
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
print(f"Model training mode: {self.training}")
print(
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
)
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
print(
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
)
self._debug_gc_printed = True
# Define the complete layer computation function for gradient checkpointing
            def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
                """Run decoder layer `layer_idx` of both models over the joint prefix+suffix sequence.

                Each model (PaliGemma LM for the prefix, Gemma expert for the suffix)
                projects Q/K/V for its own token stream; the projections are then
                concatenated along the sequence axis so the two streams attend
                jointly, and the attention output is split back and routed through
                each model's own o_proj / MLP stack. Defined as a standalone
                function so it can be wrapped by torch.utils.checkpoint.
                """
                models = [self.paligemma.language_model, self.gemma_expert.model]
                query_states = []
                key_states = []
                value_states = []
                gates = []
                for i, hidden_states in enumerate(inputs_embeds):
                    layer = models[i].layers[layer_idx]
                    # adaRMS layernorm also returns a gate, consumed by the gated residual below.
                    hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
                    gates.append(gate)
                    input_shape = hidden_states.shape[:-1]
                    hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
                    # [B, S, D] -> [B, num_heads, S, head_dim]
                    query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                    key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                    value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                    query_states.append(query_state)
                    key_states.append(key_state)
                    value_states.append(value_state)
                # Concatenate and process attention
                # dim=2 is the sequence axis after the transpose above.
                query_states = torch.cat(query_states, dim=2)
                key_states = torch.cat(key_states, dim=2)
                value_states = torch.cat(value_states, dim=2)
                # rotary_emb only reads shape/device/dtype from its first argument.
                dummy_tensor = torch.zeros(
                    query_states.shape[0],
                    query_states.shape[2],
                    query_states.shape[-1],
                    device=query_states.device,
                    dtype=query_states.dtype,
                )
                # NOTE(review): accessed via `self.paligemma.model.language_model` here but
                # `self.paligemma.language_model` elsewhere — confirm both paths alias the same module.
                cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
                query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, unsqueeze_dim=1
                )
                batch_size = query_states.shape[0]
                scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
                # Attention computation
                att_output, _ = modeling_gemma.eager_attention_forward(
                    self.paligemma.language_model.layers[layer_idx].self_attn,
                    query_states,
                    key_states,
                    value_states,
                    attention_mask,
                    scaling,
                )
                # Get head_dim from the current layer, not from the model
                head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
                # NOTE(review): `1 * 8` hardcodes 8 query heads — TODO confirm this matches both configs.
                att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
                # Process layer outputs: split the joint attention output back per stream.
                outputs_embeds = []
                start_pos = 0
                for i, hidden_states in enumerate(inputs_embeds):
                    layer = models[i].layers[layer_idx]
                    end_pos = start_pos + hidden_states.shape[1]
                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                    out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
                    # first residual (gated by the input_layernorm gate captured above)
                    out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
                    after_first_residual = out_emb.clone()
                    out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
                    # Convert to bfloat16 if the next layer (mlp) uses bfloat16
                    if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
                        out_emb = out_emb.to(dtype=torch.bfloat16)
                    out_emb = layer.mlp(out_emb)
                    # second residual
                    out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
                    outputs_embeds.append(out_emb)
                    start_pos = end_pos
                return outputs_embeds
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
)
# Old code removed - now using compute_layer_complete function above
# final norm
# Define final norm computation function for gradient checkpointing
            def compute_final_norms(inputs_embeds, adarms_cond):
                """Apply each model's final RMS norm to its own stream (checkpointable).

                `models` is captured from the enclosing scope; the adaRMS gate
                returned by norm() is discarded here.
                """
                outputs_embeds = []
                for i, hidden_states in enumerate(inputs_embeds):
                    out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                    outputs_embeds.append(out_emb)
                return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values

View File

@@ -0,0 +1,461 @@
import logging
import math
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F # noqa: N812
import openpi.models.gemma as _gemma
from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
import openpi.models_pytorch.preprocessing_pytorch as _preprocessing
def get_safe_dtype(target_dtype, device_type):
    """Return a dtype that is safe to use on the given device type.

    bfloat16 is downgraded to float32 on CPU; every other dtype is passed
    through unchanged.
    """
    cpu_needs_downgrade = device_type == "cpu" and target_dtype == torch.bfloat16
    return torch.float32 if cpu_needs_downgrade else target_dtype
def create_sinusoidal_pos_embedding(
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Compute sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: float tensor of shape (batch_size,) holding the scalar positions.
        dimension: embedding size; must be even (first half sine, second half cosine).
        min_period: shortest sinusoid period.
        max_period: longest sinusoid period (periods are geometrically spaced).
        device: torch.device or device string on which to build the embedding.

    Returns:
        Tensor of shape (batch_size, dimension).

    Raises:
        ValueError: if `dimension` is odd or `time` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")
    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
    # Bug fix: the default `device="cpu"` is a plain string, but `device.type`
    # below requires a torch.device — normalize so both forms are accepted.
    device = torch.device(device)
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction
    # Compute the outer product: each position scaled by each angular frequency.
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device):
    """Draw `bsize` samples from a Beta(alpha, beta) distribution on `device`."""
    concentration1 = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    concentration0 = torch.as_tensor(beta, dtype=torch.float32, device=device)
    return torch.distributions.Beta(concentration1, concentration0).sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks):
    """Copied from big_vision: expand 1-D padding/autoregressive masks to a 2-D attention mask.

    A token may attend to any valid token whose cumulative `att_masks` level does
    not exceed its own. This lets int[B, N] `att_masks` express several attention
    patterns, for example:

      [[1 1 1 1 1 1]]: pure causal attention.
      [[0 0 0 1 1 1]]: prefix-lm attention — the first 3 tokens attend among
        themselves, the last 3 are causal. (The first entry could also be 1
        without changing behaviour.)
      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks; a block's
        tokens attend to all previous blocks and to their own block.

    Args:
        pad_masks: bool[B, N], True where the token is real input, False for padding.
        att_masks: int32[B, N], 1 where previous tokens may NOT depend on the token,
            0 where it shares the previous token's attention pattern.

    Returns:
        bool[B, N, N] mask, True where query i may attend to key j.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)
    # Query i may see key j iff j's cumulative AR level <= i's.
    levels = att_masks.cumsum(dim=1)
    may_attend = levels[:, None, :] <= levels[:, :, None]
    # Zero out every pair that involves a padding token.
    valid_pairs = pad_masks[:, :, None] * pad_masks[:, None, :]
    return may_attend & valid_pairs
class PI0Pytorch(nn.Module):
    """PyTorch implementation of the PI0 / PI0.5 vision-language-action policy.

    Pairs a PaliGemma VLM (prefix stream: images + language) with a smaller
    Gemma "action expert" (suffix stream: state + noisy actions) and trains it
    with a flow-matching objective. `config.pi05` selects the PI0.5 variant,
    which conditions the expert on the timestep via adaRMS instead of
    concatenating the time embedding with the actions.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # PI0.5 flag: switches between adaRMS time conditioning and action/time fusion.
        self.pi05 = config.pi05
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)
        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            paligemma_config,
            action_expert_config,
            use_adarms=[False, True] if self.pi05 else [False, False],
            precision=config.dtype,
        )
        # Project 32-dim actions into the expert width and back out.
        self.action_in_proj = nn.Linear(32, action_expert_config.width)
        self.action_out_proj = nn.Linear(action_expert_config.width, 32)
        if self.pi05:
            # PI0.5: timestep goes through its own MLP and becomes the adaRMS condition.
            self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
            self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
        else:
            # PI0: state becomes an extra token; time is fused with actions via an MLP.
            self.state_proj = nn.Linear(32, action_expert_config.width)
            self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
            self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
        torch.set_float32_matmul_precision("high")
        # Compile the inference entry point; rebinds the bound method on this instance.
        self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
        # Initialize gradient checkpointing flag
        self.gradient_checkpointing_enabled = False
        # NOTE(review): this message hardcodes a user-specific conda path; the target
        # site-packages directory should be discovered dynamically.
        msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* /home/tianyang/miniconda3/envs/lam3d/lib/python3.11/site-packages/transformers/`."
        try:
            from transformers.models.siglip import check

            if not check.check_whether_transformers_replace_is_installed_correctly():
                raise ValueError(msg)
        except ImportError:
            raise ValueError(msg) from None

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self.gradient_checkpointing_enabled = True
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
        logging.info("Enabled gradient checkpointing for PI0Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing_enabled = False
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
        logging.info("Disabled gradient checkpointing for PI0Pytorch model")

    def is_gradient_checkpointing_enabled(self):
        """Check if gradient checkpointing is enabled."""
        return self.gradient_checkpointing_enabled

    def _apply_checkpoint(self, func, *args, **kwargs):
        """Helper method to apply gradient checkpointing if enabled (training mode only)."""
        if self.gradient_checkpointing_enabled and self.training:
            # Non-reentrant checkpointing; RNG state is not preserved across recompute.
            return torch.utils.checkpoint.checkpoint(
                func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
            )
        return func(*args, **kwargs)

    def _prepare_attention_masks_4d(self, att_2d_masks):
        """Helper method to prepare 4D attention masks for transformer.

        Converts bool [B, Q, K] into an additive bias [B, 1, Q, K]:
        0 where attention is allowed, a large negative value where it is not.
        """
        att_2d_masks_4d = att_2d_masks[:, None, :, :]
        return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)

    def _preprocess_observation(self, observation, *, train=True):
        """Helper method to preprocess observation.

        Returns (images, image_masks, tokenized_prompt, tokenized_prompt_mask, state).
        """
        observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
        return (
            list(observation.images.values()),
            list(observation.image_masks.values()),
            observation.tokenized_prompt,
            observation.tokenized_prompt_mask,
            observation.state,
        )

    def sample_noise(self, shape, device):
        """Sample standard-normal noise of the given shape in float32."""
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )

    def sample_time(self, bsize, device):
        """Sample flow-matching timesteps from Beta(1.5, 1.0), squeezed into [0.001, 1.0]."""
        time_beta = sample_beta(1.5, 1.0, bsize, device)
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)

    def embed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer to prepare
        for PaliGemma transformer processing.

        Returns:
            embs: [B, prefix_len, D] concatenated image + language embeddings.
            pad_masks: bool [B, prefix_len], True for real (non-masked) tokens.
            att_masks: bool [B, prefix_len], all zeros -> full bidirectional
                attention within the prefix (see make_att_2d_masks).
        """
        embs = []
        pad_masks = []
        att_masks = []
        # Process images
        for img, img_mask in zip(images, img_masks, strict=True):

            def image_embed_func(img):
                return self.paligemma_with_expert.embed_image(img)

            img_emb = self._apply_checkpoint(image_embed_func, img)
            bsize, num_img_embs = img_emb.shape[:2]
            embs.append(img_emb)
            # One per-camera mask bit is broadcast over all of that camera's patch tokens.
            pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        # Process language tokens
        def lang_embed_func(lang_tokens):
            lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
            lang_emb_dim = lang_emb.shape[-1]
            # Gemma convention: scale token embeddings by sqrt(width).
            return lang_emb * math.sqrt(lang_emb_dim)

        lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
        embs.append(lang_emb)
        pad_masks.append(lang_masks)
        # full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs
        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
        # Get batch size from the first dimension of the concatenated tensors
        bsize = pad_masks.shape[0]
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))
        return embs, pad_masks, att_masks

    def embed_suffix(self, state, noisy_actions, timestep):
        """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.

        Returns (embs, pad_masks, att_masks, adarms_cond); adarms_cond is None
        for PI0 and the time-MLP output for PI0.5.
        """
        embs = []
        pad_masks = []
        att_masks = []
        if not self.pi05:
            if self.state_proj.weight.dtype == torch.float32:
                state = state.to(torch.float32)

            # Embed state
            def state_proj_func(state):
                return self.state_proj(state)

            state_emb = self._apply_checkpoint(state_proj_func, state)
            embs.append(state_emb[:, None, :])
            bsize = state_emb.shape[0]
            device = state_emb.device
            state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
            pad_masks.append(state_mask)
            # Set attention masks so that image and language inputs do not attend to state or actions
            att_masks += [1]
        # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = create_sinusoidal_pos_embedding(
            timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device
        )
        time_emb = time_emb.type(dtype=timestep.dtype)

        # Fuse timestep + action information using an MLP
        def action_proj_func(noisy_actions):
            return self.action_in_proj(noisy_actions)

        action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
        if not self.pi05:
            # PI0: concatenate the (broadcast) time embedding onto every action token.
            time_emb = time_emb[:, None, :].expand_as(action_emb)
            action_time_emb = torch.cat([action_emb, time_emb], dim=2)

            # Apply MLP layers
            def mlp_func(action_time_emb):
                x = self.action_time_mlp_in(action_time_emb)
                x = F.silu(x)  # swish == silu
                return self.action_time_mlp_out(x)

            action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
            adarms_cond = None
        else:
            # time MLP (for adaRMS)
            def time_mlp_func(time_emb):
                x = self.time_mlp_in(time_emb)
                x = F.silu(x)  # swish == silu
                x = self.time_mlp_out(x)
                return F.silu(x)

            time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
            action_time_emb = action_emb
            adarms_cond = time_emb
        # Add to input tokens
        embs.append(action_time_emb)
        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
        pad_masks.append(action_time_mask)
        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.action_horizon - 1))
        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        # NOTE(review): these att_masks use embs.dtype (float) while embed_prefix uses
        # bool; the torch.cat in forward() relies on implicit type promotion — confirm.
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))
        return embs, pad_masks, att_masks, adarms_cond

    def forward(self, observation, actions, noise=None, time=None) -> Tensor:
        """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)
        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)
        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)
        # Flow matching: x_t interpolates noise (t=1) and data (t=0);
        # the regression target u_t is the constant velocity noise - actions.
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions
        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
        # Match the embeddings' dtype to the transformer weights when running in bf16.
        if (
            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
            == torch.bfloat16
        ):
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        # Position ids count only non-padding tokens.
        position_ids = torch.cumsum(pad_masks, dim=1) - 1
        # Prepare attention masks
        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)

        # Apply gradient checkpointing if enabled
        def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
            (_, suffix_out), _ = self.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            return suffix_out

        suffix_out = self._apply_checkpoint(
            forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
        )
        # Keep only the action tokens (the last action_horizon positions of the suffix).
        suffix_out = suffix_out[:, -self.config.action_horizon :]
        suffix_out = suffix_out.to(dtype=torch.float32)

        # Apply gradient checkpointing to final action projection if enabled
        def action_out_proj_func(suffix_out):
            return self.action_out_proj(suffix_out)

        v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
        # Per-element loss; reduction is left to the caller.
        return F.mse_loss(u_t, v_t, reduction="none")

    @torch.no_grad()
    def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
        """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
        bsize = observation.state.shape[0]
        if noise is None:
            actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
            noise = self.sample_noise(actions_shape, device)
        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
        # Compute image and language key value cache
        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001
        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=True,
        )
        # Integrate the flow ODE from t=1 (noise) down to t=0 (actions) with Euler steps.
        dt = -1.0 / num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)
        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        # dt is negative, so this loops until time reaches (approximately) zero.
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                state,
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )
            # Euler step - use new tensor assignment instead of in-place operation
            x_t = x_t + dt * v_t
            time += dt
        return x_t

    def denoise_step(
        self,
        state,
        prefix_pad_masks,
        past_key_values,
        x_t,
        timestep,
    ):
        """Apply one denoising step of the noise `x_t` at a given timestep."""
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]
        # Suffix queries may attend to every valid prefix token (whose K/V are cached)...
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
        # ...and to suffix tokens per the usual block-causal rule.
        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
        # Continue position ids where the prefix left off.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
        # Prepare attention masks
        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001
        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks_4d,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=False,
            adarms_cond=[None, adarms_cond],
        )
        suffix_out = outputs_embeds[1]
        suffix_out = suffix_out[:, -self.config.action_horizon :]
        suffix_out = suffix_out.to(dtype=torch.float32)
        return self.action_out_proj(suffix_out)

View File

@@ -0,0 +1,173 @@
from collections.abc import Sequence
import logging
import torch
from openpi.shared import image_tools
logger = logging.getLogger("openpi")
# Constants moved from model.py
# Camera keys expected in every observation; iteration order fixes token order.
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)
# Target (height, width) every camera image is resized/padded to.
IMAGE_RESOLUTION = (224, 224)
def preprocess_observation_pytorch(
    observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
    """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.

    This function avoids complex type annotations that can cause torch.compile issues.

    Resizes each camera image to `image_resolution` and, when `train` is True,
    applies random crop + rotation (non-wrist cameras only) and brightness/
    contrast/saturation jitter (all cameras). Images are assumed to arrive in
    [-1, 1] and are returned in the same range and in their original layout
    ([B, C, H, W] or [B, H, W, C]).

    Args:
        observation: object exposing .images (dict of image tensors),
            .image_masks, .state, .tokenized_prompt, .tokenized_prompt_mask,
            .token_ar_mask, .token_loss_mask.
        train: whether to apply data augmentation.
        image_keys: camera keys that must all be present in observation.images.
        image_masks: missing masks default to all-True (no masking).
        image_resolution: target (height, width).

    Returns:
        A lightweight object mirroring the Observation fields, with processed images/masks.

    Raises:
        ValueError: if any of `image_keys` is missing from observation.images.
    """
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
    batch_shape = observation.state.shape[:-1]
    out_images = {}
    for key in image_keys:
        image = observation.images[key]
        # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
        # Handle both [B, C, H, W] and [B, H, W, C] formats
        # NOTE(review): this heuristic misfires for 3-pixel-high channels-last images.
        is_channels_first = image.shape[1] == 3  # Check if channels are in dimension 1
        if is_channels_first:
            # Convert [B, C, H, W] to [B, H, W, C] for processing
            image = image.permute(0, 2, 3, 1)
        if image.shape[1:3] != image_resolution:
            # NOTE(review): this logs on every call for mismatched inputs — consider logging once.
            logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
            image = image_tools.resize_with_pad_torch(image, *image_resolution)
        if train:
            # Convert from [-1, 1] to [0, 1] for PyTorch augmentations
            image = image / 2.0 + 0.5
            # Apply PyTorch-based augmentations
            if "wrist" not in key:
                # Geometric augmentations for non-wrist cameras
                height, width = image.shape[1:3]
                # Random crop and resize: crop to 95% of each side, then resize back.
                crop_height = int(height * 0.95)
                crop_width = int(width * 0.95)
                # Random crop
                max_h = height - crop_height
                max_w = width - crop_width
                if max_h > 0 and max_w > 0:
                    # Use tensor operations instead of .item() for torch.compile compatibility
                    start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
                    start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
                    # 1-element tensors are valid slice bounds via __index__.
                    image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
                    # Resize back to original size
                    image = torch.nn.functional.interpolate(
                        image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]
                        size=(height, width),
                        mode="bilinear",
                        align_corners=False,
                    ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
                # Random rotation (small angles)
                # Use tensor operations instead of .item() for torch.compile compatibility
                angle = torch.rand(1, device=image.device) * 10 - 5  # Random angle between -5 and 5 degrees
                # NOTE(review): data-dependent branch; breaks full-graph torch.compile capture — confirm acceptable.
                if torch.abs(angle) > 0.1:  # Only rotate if angle is significant
                    # Convert to radians
                    angle_rad = angle * torch.pi / 180.0
                    # Create rotation matrix
                    cos_a = torch.cos(angle_rad)
                    sin_a = torch.sin(angle_rad)
                    # Apply rotation using grid_sample
                    grid_x = torch.linspace(-1, 1, width, device=image.device)
                    grid_y = torch.linspace(-1, 1, height, device=image.device)
                    # Create meshgrid
                    grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
                    # Expand to batch dimension
                    grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
                    grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
                    # Apply rotation transformation
                    grid_x_rot = grid_x * cos_a - grid_y * sin_a
                    grid_y_rot = grid_x * sin_a + grid_y * cos_a
                    # Stack and reshape for grid_sample
                    grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
                    image = torch.nn.functional.grid_sample(
                        image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]
                        grid,
                        mode="bilinear",
                        padding_mode="zeros",
                        align_corners=False,
                    ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
            # Color augmentations for all cameras
            # Random brightness
            # Use tensor operations instead of .item() for torch.compile compatibility
            brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6  # Random factor between 0.7 and 1.3
            image = image * brightness_factor
            # Random contrast
            # Use tensor operations instead of .item() for torch.compile compatibility
            contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8  # Random factor between 0.6 and 1.4
            mean = image.mean(dim=[1, 2, 3], keepdim=True)
            image = (image - mean) * contrast_factor + mean
            # Random saturation (convert to HSV, modify S, convert back)
            # For simplicity, we'll just apply a random scaling to the color channels
            # Use tensor operations instead of .item() for torch.compile compatibility
            saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0  # Random factor between 0.5 and 1.5
            gray = image.mean(dim=-1, keepdim=True)
            image = gray + (image - gray) * saturation_factor
            # Clamp values to [0, 1]
            image = torch.clamp(image, 0, 1)
            # Back to [-1, 1]
            image = image * 2.0 - 1.0
        # Convert back to [B, C, H, W] format if it was originally channels-first
        if is_channels_first:
            image = image.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        out_images[key] = image
    # obtain mask
    out_masks = {}
    for key in out_images:
        if key not in observation.image_masks:
            # do not mask by default
            out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
        else:
            out_masks[key] = observation.image_masks[key]

    # Create a simple object with the required attributes instead of using the complex Observation class
    class SimpleProcessedObservation:
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    return SimpleProcessedObservation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )

View File

@@ -0,0 +1,173 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...configuration_utils import PretrainedConfig
class GemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The legacy activation function. It is overwritten by the `hidden_activation`.
hidden_activation (`str` or `function`, *optional*):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_adarms (`bool`, *optional*, defaults to `False`):
Whether to use ADARMS.
adarms_cond_dim (`int`, *optional*, defaults to `None`):
The dimension of the ADARMS condition.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        use_adarms: bool = False,
        adarms_cond_dim: Optional[int] = None,
        **kwargs,
    ):
        """Store Gemma hyperparameters; see the class docstring for argument semantics.

        Attributes are assigned before ``super().__init__`` so the HF base config
        can serialize them alongside the token ids handled by the parent.
        """
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_adarms = use_adarms
        self.adarms_cond_dim = adarms_cond_dim
        # Set default for adarms_cond_dim if use_adarms is True
        # (the adaRMS condition then has the same width as the hidden states).
        if self.use_adarms and self.adarms_cond_dim is None:
            self.adarms_cond_dim = self.hidden_size
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# Public API of this configuration module.
__all__ = ["GemmaConfig"]

View File

@@ -0,0 +1,862 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
class GemmaRMSNorm(nn.Module):
    """RMS normalization with an optional adaptive (conditioned) scale/shift/gate.

    Plain mode computes ``x * rsqrt(mean(x^2) + eps) * (1 + weight)`` with a
    learned per-feature ``weight`` initialized to zeros. Adaptive mode (when
    ``cond_dim`` is given) derives scale/shift/gate from a condition vector via a
    zero-initialized linear layer; the gate is returned so callers can build
    gated residual connections.
    """

    def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.cond_dim = cond_dim
        # Dense layer for adaptive normalization (if cond_dim is provided)
        if cond_dim is not None:
            # One projection emits (scale, shift, gate), hence dim * 3 outputs.
            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
            # Initialize weight with zeros (matches source implementation).
            # NOTE(review): the bias keeps PyTorch's default uniform init, so the
            # modulation is not exactly zero at init — confirm against the
            # reference implementation whether the bias should be zeroed too.
            nn.init.zeros_(self.dense.weight)
        else:
            # Stored as zeros because forward() scales by (1 + weight).
            self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
            self.dense = None

    def _norm(self, x):
        # Compute variance and normalization in float32 for numerical stability
        # (like the source implementation), regardless of the input dtype.
        var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps)

    def forward(self, x, cond=None):
        """Return ``(normalized, gate)``; ``gate`` is None on the non-adaptive path."""
        dtype = x.dtype  # original dtype, could be half-precision
        normed_inputs = self._norm(x)
        if cond is None or self.dense is None:
            # Regular RMSNorm: scale by the learned parameter in float32.
            normed_inputs = normed_inputs * (1.0 + self.weight.float())
            return normed_inputs.to(dtype), None
        # Adaptive RMSNorm (cond is provided and the dense layer exists).
        if cond.shape[-1] != self.cond_dim:
            raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
        modulation = self.dense(cond)
        # Reshape modulation to broadcast properly: [batch, 1, features] against
        # [batch, seq, features] inputs.
        if len(x.shape) == 3:
            modulation = modulation.unsqueeze(1)
        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
        return normed_inputs.to(dtype), gate.to(dtype)

    def extra_repr(self):
        # Bug fix: the adaptive path never creates `self.weight`, so the previous
        # `tuple(self.weight.shape)` raised AttributeError in repr(); report the
        # normalized dimension directly instead (same string in the plain case).
        repr_str = f"({self.dim},), eps={self.eps}"
        if self.dense is not None:
            repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
        return repr_str
class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Two parallel input projections (gate and up) feed one output projection;
        # none of them carries a bias term.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Activation gates the "up" branch element-wise before projecting back down.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class GemmaRotaryEmbedding(nn.Module):
    """Computes RoPE (cos, sin) tables for a batch of position ids."""

    def __init__(self, config: GemmaConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # Resolve the init function for the configured RoPE variant and build the
        # inverse-frequency table plus the variant's attention scaling factor.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Buffer, not parameter: excluded from state_dict via persistent=False.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` tables shaped for `apply_rotary_pos_emb`, in x's dtype."""
        # [1, dim/2, 1] inverse frequencies broadcast against [B, 1, T] positions.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the half-dim frequency table so cos/sin span the full head dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis along which cos/sin are unsqueezed so they broadcast against q and k.
            Use 1 for [batch, heads, seq, head_dim] layouts and 2 for
            [batch, seq, heads, head_dim] layouts.

    Returns:
        `tuple(torch.Tensor)` with the rotated query and key tensors.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    rotated_q = (q * cos) + (rotate_half(q) * sin)
    rotated_k = (k * cos) + (rotate_half(k) * sin)
    return rotated_q, rotated_k
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands
    (batch, num_key_value_heads, seqlen, head_dim) key/value states to
    (batch, num_attention_heads, seqlen, head_dim) so each k/v head serves
    n_rep query heads.
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand() is a view (no copy); reshape materializes the repeated heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
def _gated_residual(x, y, gate):
"""
Applies gated residual connection with optional gate parameter.
Args:
x: Input tensor (residual)
y: Output tensor to be added
gate: Optional gate tensor to modulate the addition
Returns:
x + y if gate is None, otherwise x + y * gate
"""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention with GQA head expansion.

    Returns the attention output in [B, T, heads, head_dim] layout plus the
    attention probabilities.
    """
    # Expand key/value heads so they line up one-to-one with the query heads.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask is an additive bias; crop it to the actual key length.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]
    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, value_states)
    context = context.transpose(1, 2).contiguous()
    return context, probs
class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx  # identifies this layer's slot in the KV cache
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Query heads served by each key/value head (grouped-query attention).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Project to q/k/v, apply RoPE, optionally update the KV cache, and attend."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # [B, T, D] -> [B, num_heads, T, head_dim]
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # Use cache if provided
        if past_key_value is not None:
            if use_cache:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            else:
                # NOTE(review): this read-only path indexes the cache like a legacy
                # (key, value) tuple per layer — confirm callers pass a tuple-style
                # cache when use_cache is False.
                key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
                value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            # Dispatch to the configured fused backend (e.g. sdpa / flash-attention).
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        # Merge heads back to [B, T, num_heads * head_dim], then project to model dim.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class GemmaDecoderLayer(GradientCheckpointingLayer):
    """One Gemma block: (ada)RMSNorm -> self-attention -> gated residual, then
    (ada)RMSNorm -> MLP -> gated residual."""

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
        self.mlp = GemmaMLP(config)
        # When adaRMS is enabled, both layernorms are conditioned and also emit a
        # gate used to modulate their residual connection.
        cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        adarms_cond: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder block; returns (hidden_states,) plus weights if requested."""
        residual = hidden_states
        # The norm returns a gate only when adarms_cond is given; otherwise None.
        hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = _gated_residual(residual, hidden_states, gate)
        # Fully Connected
        residual = hidden_states
        hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
        hidden_states = self.mlp(hidden_states)
        hidden_states = _gated_residual(residual, hidden_states, gate)
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs
@auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
    """Base class wiring Gemma modules into the HF PreTrainedModel machinery
    (weight init, device placement, supported attention backends)."""

    config_class = GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GemmaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """From-scratch weight initialization for Gemma submodules."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GemmaRMSNorm):
            # Bug fix: GemmaRMSNorm.forward scales by (1.0 + weight), so the
            # neutral init is zero; the previous fill_(1.0) doubled activations
            # for freshly initialized models. Upstream transformers also
            # zero-initializes Gemma RMSNorm weights. Adaptive norms have no
            # `weight` attribute, hence the hasattr guard.
            if hasattr(module, 'weight'):
                module.weight.data.zero_()
@auto_docstring
class GemmaModel(GemmaPreTrainedModel):
    """Bare Gemma decoder stack: token embeddings, decoder layers, final norm."""

    def __init__(self, config: GemmaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # The final norm is adaptive only when adaRMS conditioning is enabled.
        cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
        self.rotary_emb = GemmaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        adarms_cond: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """
        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        if cache_position is None:
            # Continue position numbering after whatever is already cached.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        # embed positions
        hidden_states = inputs_embeds
        # Convert to bfloat16 if the first layer uses bfloat16.
        # NOTE(review): this infers the compute dtype from the first q_proj weight;
        # confirm embeddings are never intentionally kept in a different dtype.
        if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.bfloat16)
        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        # normalized
        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        # NOTE(review): the sqrt(hidden_size) embedding scaling is computed but the
        # multiply below is disabled — confirm this is intentional for this fork.
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        #hidden_states = hidden_states * normalizer
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                adarms_cond=adarms_cond,
                **kwargs,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        # Final norm; the adaRMS gate is unused at the top of the stack.
        hidden_states, _ = self.norm(hidden_states, adarms_cond)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
# Aggregate kwargs type accepted by the causal-LM forward (attention + loss kwargs).
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
    """Gemma decoder with a language-modeling head for next-token prediction."""

    # lm_head shares weights with the input embeddings when tie_word_embeddings is set.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        adarms_cond: Optional[torch.Tensor] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        Example:
        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM
        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            adarms_cond=adarms_cond,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The Gemma Model transformer with a sequence classification head on top (linear layer).
    [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.
    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class GemmaForSequenceClassification(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = GemmaModel(config)
        # Classification head applied to the pooled (last non-pad) token state.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        adarms_cond: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        """
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            adarms_cond=adarms_cond,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )
        # Pool one logit vector per sequence at its last non-pad position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
@auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
    """Gemma decoder with a per-token classification head (e.g. for NER)."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = GemmaModel(config)
        # Dropout rate: explicit classifier_dropout wins, then hidden_dropout,
        # else a 0.1 default.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        adarms_cond: Optional[torch.Tensor] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            adarms_cond=adarms_cond,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        # One logit vector per token position.
        logits = self.score(sequence_output)
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this modeling module.
__all__ = [
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",
]

View File

@@ -0,0 +1,622 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PaliGemma model."""
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from ..auto import AutoModel
from .configuration_paligemma import PaliGemmaConfig
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Paligemma outputs, with hidden states and attentions.
    """
)
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """
    # Projected vision-tower features; stays None when no pixel_values were given.
    image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for PaliGemma causal language model (or autoregressive) outputs.
    """
)
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    """
    # All fields default to None so the output can be constructed partially.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
class PaliGemmaMultiModalProjector(nn.Module):
    """Projects vision-tower features into the language model's embedding space."""

    def __init__(self, config: PaliGemmaConfig):
        super().__init__()
        # Single affine map: vision hidden size -> shared projection dimension.
        self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)

    def forward(self, image_features):
        """Apply the linear projection to `image_features` and return the result."""
        return self.linear(image_features)
@auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
    # Shared base class: wires the PaliGemma config into transformers' loading,
    # device placement, caching, and attention-backend machinery.
    config_class = PaliGemmaConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    # The projector must not be split across devices by `accelerate`.
    _no_split_modules = ["PaliGemmaMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    def _init_weights(self, module):
        # important: this ported version of PaliGemma isn't meant for training from scratch - only
        # inference and fine-tuning
        std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
@auto_docstring(
    custom_intro="""
    The Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.,
    """
)
class PaliGemmaModel(PaliGemmaPreTrainedModel):
    # Maps legacy checkpoint key prefixes onto the current module layout.
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False
    def __init__(self, config: PaliGemmaConfig):
        # Composes the three sub-modules: vision tower, projector, language model.
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        language_model = AutoModel.from_config(config=config.text_config)
        self.language_model = language_model
        # Fall back to -1 when the config defines no pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()
    # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()
    # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)
    def set_decoder(self, decoder):
        self.language_model = decoder
    def get_decoder(self):
        return self.language_model
    def _update_causal_mask(
        self,
        attention_mask,
        token_type_ids=None,
        past_key_values=None,
        cache_position=None,
        input_tensor=None,
        is_training: Optional[bool] = None,
    ):
        # Builds PaliGemma's 4D additive attention mask: full attention over the
        # image+prefix tokens (causal over the suffix during training), plus
        # padding masking. Additive: 0 = attend, dtype-min = masked.
        if self.config.text_config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the raw 2D mask (or None for full attention).
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        is_training = is_training if is_training is not None else self.training
        using_static_cache = isinstance(past_key_values, StaticCache)
        min_dtype = torch.finfo(self.dtype).min
        if input_tensor is None:
            input_tensor = attention_mask
        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
        # Key/value length: the static/hybrid cache fixes it; otherwise derive it
        # from the mask or from the current cache position.
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            return attention_mask
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
        if sequence_length != 1:
            if is_training:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                causal_mask[:, :sequence_length] = 0.0
        # Mask out cache slots that lie beyond the current positions.
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # First unmask prefix tokens during training
            if is_training:
                if token_type_ids is None:
                    raise ValueError("Token type ids must be provided during training")
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
            # Then apply padding mask (will mask pad tokens)
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )
        return causal_mask
    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, PaligemmaModelOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
        Example:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Prefix-LM training mode requires both token_type_ids and labels.
        is_training = token_type_ids is not None and labels is not None
        # Replace image token id with PAD if the image token is OOV, to avoid index-errors
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed
        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)
            if input_ids is None:
                # Without input_ids, locate image slots by comparing each embedding
                # against the image-token embedding.
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            # Sanity check (skipped under torch.compile): every image placeholder
            # token must be covered by exactly the projected image features.
            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            # Scatter projected image features into the text embedding sequence.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
        )
        outputs = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        return PaligemmaModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )
# Typed-kwargs marker combining flash-attention and loss kwargs for the LM head.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring(
    custom_intro="""
    The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
    """
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
    # Maps legacy checkpoint key prefixes onto the composed `model.` layout.
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    # The LM head shares weights with the input embedding table.
    _tied_weights_keys = ["lm_head.weight"]
    def __init__(self, config: PaliGemmaConfig):
        # Wraps the base PaliGemmaModel and adds an (untied-bias-free) LM head.
        super().__init__(config)
        self.model = PaliGemmaModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()
    def get_input_embeddings(self):
        return self.model.get_input_embeddings()
    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)
    def get_output_embeddings(self):
        return self.lm_head
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)
    def get_decoder(self):
        return self.model.get_decoder()
    def get_image_features(self, pixel_values):
        return self.model.get_image_features(pixel_values)
    # Make modules available through conditional class for BC
    @property
    def language_model(self):
        return self.model.language_model
    @property
    def vision_tower(self):
        return self.model.vision_tower
    @property
    def multi_modal_projector(self):
        return self.model.multi_modal_projector
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
        Example:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Delegate the multimodal fusion + decoding to the base model.
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )
        return PaliGemmaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        # position_ids in Paligemma are 1-indexed
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] += 1
        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
        is_training = token_type_ids is not None and labels is not None
        # Prefill step with a HybridCache: build the full 4D mask up front.
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
            causal_mask = self.model._update_causal_mask(
                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
            )
            model_inputs["attention_mask"] = causal_mask
        return model_inputs
    @staticmethod
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Mask out cache slots beyond the current positions.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
# Public API of this module.
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]

View File

@@ -0,0 +1,4 @@
import transformers
def check_whether_transformers_replace_is_installed_correctly():
    """Return True when the pinned transformers version (4.53.2) is installed."""
    expected_version = "4.53.2"
    return transformers.__version__ == expected_version

View File

@@ -0,0 +1,202 @@
import dataclasses
from typing import ClassVar
import einops
import numpy as np
from openpi import transforms
def make_aloha_example() -> dict:
    """Creates a random input example for the Aloha policy."""

    def _random_cam() -> np.ndarray:
        # Channel-first uint8 image, matching what the Aloha runtime provides.
        return np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)

    camera_names = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
    return {
        "state": np.ones((14,)),
        "images": {name: _random_cam() for name in camera_names},
        "prompt": "do something",
    }
@dataclasses.dataclass(frozen=True)
class AlohaInputs(transforms.DataTransformFn):
    """Inputs for the Aloha policy.

    Expected inputs:
    - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
    - state: [14]
    - actions: [action_horizon, 14]
    """

    # If true, convert joint/gripper values from the standard Aloha space to the
    # space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi: bool = True

    # All input cameras must come from this set; missing cameras are replaced with
    # black images and flagged False in the corresponding `image_mask`.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        in_images = data["images"]
        if set(in_images) - set(self.EXPECTED_CAMERAS):
            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

        # The base (high) camera is assumed to always be present.
        base_image = in_images["cam_high"]
        images = {"base_0_rgb": base_image}
        image_masks = {"base_0_rgb": np.True_}

        # Wrist cameras are optional: substitute black frames and mask them out.
        for dest, source in (("left_wrist_0_rgb", "cam_left_wrist"), ("right_wrist_0_rgb", "cam_right_wrist")):
            present = source in in_images
            images[dest] = in_images[source] if present else np.zeros_like(base_image)
            image_masks[dest] = np.True_ if present else np.False_

        inputs = {
            "image": images,
            "image_mask": image_masks,
            "state": data["state"],
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
@dataclasses.dataclass(frozen=True)
class AlohaOutputs(transforms.DataTransformFn):
    """Outputs for the Aloha policy."""

    # If true, convert joint/gripper values from the pi internal runtime space
    # back to the standard Aloha space.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Aloha consumes only the first 14 action dims (two 7-dof arms).
        aloha_actions = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(aloha_actions, adapt_to_pi=self.adapt_to_pi)}
def _joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def _normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def _unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def _gripper_to_angular(value):
    """Convert an Aloha linear gripper position into pi0's normalized angular space."""
    # Aloha transforms the gripper positions into a linear space; reverse that
    # transformation so the data matches pi0, which is pretrained in angular space.
    # Range constants come from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # Inverse of the angular-to-linear transformation inside the Interbotix code;
    # arm length / horn radius constants are taken from that code.
    arm_length = 0.036
    horn_radius = 0.022
    ratio = (horn_radius**2 + linear**2 - arm_length**2) / (2 * horn_radius * linear)
    radians = np.arcsin(np.clip(ratio, -1.0, 1.0))

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # With 4096 total encoder counts and an aloha zero of 2048, the normalized
    # inputs in radians fall between (0.5476, 1.6296).
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
def _gripper_from_angular(value):
    """Convert a pi0 gripper command into the (angular) range used by Aloha."""
    # The units stay angular but the range differs; no scaling is applied because
    # trossen model predictions are already in radians. The 0.5476 offset is
    # derived in _gripper_to_angular.
    shifted = value + 0.5476
    # Range constants come from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
def _gripper_from_angular_inv(value):
    """Directly inverts the _gripper_from_angular function."""
    unnormalized = _unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return unnormalized - 0.5476
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Normalize a raw Aloha observation dict in place: adapt the 14-dim state
    and convert every camera image to uint8 HWC layout."""
    # state layout: [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]
    # dim sizes:    [6, 1, 6, 1]
    decoded_state = _decode_state(np.asarray(data["state"]), adapt_to_pi=adapt_to_pi)

    def _to_uint8_hwc(img):
        arr = np.asarray(img)
        if np.issubdtype(arr.dtype, np.floating):
            # Float images are scaled from [0, 1] to byte range.
            arr = (255 * arr).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(arr, "c h w -> h w c")

    data["images"] = {name: _to_uint8_hwc(img) for name, img in data["images"].items()}
    data["state"] = decoded_state
    return data
def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Map an Aloha state vector into the pi runtime convention when requested."""
    if not adapt_to_pi:
        return state
    # Flip the sign-inverted joints, then reverse the Aloha runtime's linear
    # gripper transformation on both gripper dims (6 and 13).
    state = _joint_flip_mask() * state
    state[[6, 13]] = _gripper_to_angular(state[[6, 13]])
    return state
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Map pi-space actions into Aloha-space actions when requested."""
    if not adapt_to_pi:
        return actions
    # Flip the sign-inverted joints, then convert gripper dims (6, 13) to Aloha's range.
    actions = _joint_flip_mask() * actions
    actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]])
    return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Inverse of _encode_actions: map Aloha-space actions back into pi space."""
    if not adapt_to_pi:
        return actions
    actions = _joint_flip_mask() * actions
    actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]])
    return actions

View File

@@ -0,0 +1,81 @@
import dataclasses
import einops
import numpy as np
from openpi import transforms
from openpi.models import model as _model
def make_droid_example() -> dict:
    """Creates a random input example for the Droid policy."""
    image_shape = (224, 224, 3)  # height, width, channel (uint8)
    return {
        "observation/exterior_image_1_left": np.random.randint(256, size=image_shape, dtype=np.uint8),
        "observation/wrist_image_left": np.random.randint(256, size=image_shape, dtype=np.uint8),
        "observation/joint_position": np.random.rand(7),
        "observation/gripper_position": np.random.rand(1),
        "prompt": "do something",
    }
def _parse_image(image) -> np.ndarray:
image = np.asarray(image)
if np.issubdtype(image.dtype, np.floating):
image = (255 * image).astype(np.uint8)
if image.shape[0] == 3:
image = einops.rearrange(image, "c h w -> h w c")
return image
@dataclasses.dataclass(frozen=True)
class DroidInputs(transforms.DataTransformFn):
    """Converts raw Droid observations into the model's expected input format."""

    # Determines which model will be used.
    model_type: _model.ModelType

    def __call__(self, data: dict) -> dict:
        # The gripper position may arrive as a scalar; promote it to 1-D so it
        # can be concatenated with the 7 joint positions into an 8-dim state.
        gripper_pos = np.asarray(data["observation/gripper_position"])
        if gripper_pos.ndim == 0:
            gripper_pos = gripper_pos[np.newaxis]
        state = np.concatenate([data["observation/joint_position"], gripper_pos])

        # Images may be stored as float32 (C, H, W) by LeRobot; normalize them to
        # uint8 (H, W, C). This is a no-op during policy inference.
        base_image = _parse_image(data["observation/exterior_image_1_left"])
        wrist_image = _parse_image(data["observation/wrist_image_left"])

        if self.model_type in (_model.ModelType.PI0, _model.ModelType.PI05):
            names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
            images = (base_image, wrist_image, np.zeros_like(base_image))
            image_masks = (np.True_, np.True_, np.False_)
        elif self.model_type == _model.ModelType.PI0_FAST:
            names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb")
            # We don't mask out padding images for FAST models.
            images = (base_image, np.zeros_like(base_image), wrist_image)
            image_masks = (np.True_, np.True_, np.True_)
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        inputs = {
            "state": state,
            "image": dict(zip(names, images, strict=True)),
            "image_mask": dict(zip(names, image_masks, strict=True)),
        }

        if "actions" in data:
            inputs["actions"] = np.asarray(data["actions"])

        if "prompt" in data:
            # Prompts may be stored as raw bytes; decode them in place like the
            # original pipeline does.
            if isinstance(data["prompt"], bytes):
                data["prompt"] = data["prompt"].decode("utf-8")
            inputs["prompt"] = data["prompt"]

        return inputs
@dataclasses.dataclass(frozen=True)
class DroidOutputs(transforms.DataTransformFn):
    """Converts model outputs back into the DROID action space."""

    def __call__(self, data: dict) -> dict:
        # The model may pad actions; DROID uses only the first 8 dims.
        actions = np.asarray(data["actions"])
        return {"actions": actions[:, :8]}

View File

@@ -0,0 +1,100 @@
import dataclasses
import einops
import numpy as np
from openpi import transforms
from openpi.models import model as _model
def make_libero_example() -> dict:
    """Creates a random input example for the Libero policy."""
    image_shape = (224, 224, 3)
    example = {
        "observation/state": np.random.rand(8),
        "observation/image": np.random.randint(256, size=image_shape, dtype=np.uint8),
        "observation/wrist_image": np.random.randint(256, size=image_shape, dtype=np.uint8),
        "prompt": "do something",
    }
    return example
def _parse_image(image) -> np.ndarray:
image = np.asarray(image)
if np.issubdtype(image.dtype, np.floating):
image = (255 * image).astype(np.uint8)
if image.shape[0] == 3:
image = einops.rearrange(image, "c h w -> h w c")
return image
@dataclasses.dataclass(frozen=True)
class LiberoInputs(transforms.DataTransformFn):
    """Converts Libero dataset items into the model input format.

    Used for both training and inference. To adapt this to your own dataset,
    copy the class and adjust the keys noted in the comments below.
    """

    # Determines which model will be used. Do not change this for your own dataset.
    model_type: _model.ModelType

    def __call__(self, data: dict) -> dict:
        # LeRobot stores images as float32 (C, H, W); convert to uint8 (H, W, C).
        # This is a no-op during policy inference when images are already parsed.
        # If your dataset stores images under different keys than
        # "observation/image" / "observation/wrist_image", change them here.
        base_image = _parse_image(data["observation/image"])
        wrist_image = _parse_image(data["observation/wrist_image"])

        # Pi0 models currently accept three camera views: one third-person view
        # and two wrist views (left and right). Views your dataset lacks are
        # padded with zero arrays, like the right wrist below.
        # Do not change the keys in the dicts below.
        images = {
            "base_0_rgb": base_image,
            "left_wrist_0_rgb": wrist_image,
            "right_wrist_0_rgb": np.zeros_like(base_image),
        }
        # Padding images are masked out for pi0 but not for pi0-FAST.
        # Do not change this for your own dataset.
        masks = {
            "base_0_rgb": np.True_,
            "left_wrist_0_rgb": np.True_,
            "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
        }
        inputs = {
            "state": data["observation/state"],
            "image": images,
            "image_mask": masks,
        }

        # Actions are only available during training. Keep this for your own dataset.
        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Pass the language instruction through; the output dict always needs the
        # key "prompt" (change the source key if your dataset uses another one).
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs
@dataclasses.dataclass(frozen=True)
class LiberoOutputs(transforms.DataTransformFn):
    """Converts model outputs back to the dataset-specific format (inference only).

    For your own dataset, copy this class and change the action dimension
    sliced below.
    """

    def __call__(self, data: dict) -> dict:
        # Actions were padded to the model action dimension on the way in, so
        # slice out the real ones here. Libero uses the first 7 dims; replace
        # `7` with your dataset's action dimension.
        return {"actions": np.asarray(data["actions"][:, :7])}

View File

@@ -0,0 +1,135 @@
from collections.abc import Sequence
import logging
import pathlib
import time
from typing import Any, TypeAlias
import flax
import flax.traverse_util
import jax
import jax.numpy as jnp
import numpy as np
from openpi_client import base_policy as _base_policy
import torch
from typing_extensions import override
from openpi import transforms as _transforms
from openpi.models import model as _model
from openpi.shared import array_typing as at
from openpi.shared import nnx_utils
BasePolicy: TypeAlias = _base_policy.BasePolicy
class Policy(BasePolicy):
    """Wraps a JAX or PyTorch model behind the `BasePolicy` interface.

    Applies input transforms, runs `model.sample_actions`, then applies output
    transforms; see `infer` for the full pipeline.
    """

    def __init__(
        self,
        model: _model.BaseModel,
        *,
        rng: at.KeyArrayLike | None = None,
        transforms: Sequence[_transforms.DataTransformFn] = (),
        output_transforms: Sequence[_transforms.DataTransformFn] = (),
        sample_kwargs: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        pytorch_device: str = "cpu",
        is_pytorch: bool = False,
    ):
        """Initialize the Policy.
        Args:
            model: The model to use for action sampling.
            rng: Random number generator key for JAX models. Ignored for PyTorch models.
            transforms: Input data transformations to apply before inference.
            output_transforms: Output data transformations to apply after inference.
            sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
            metadata: Additional metadata to store with the policy.
            pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
                Only relevant when is_pytorch=True.
            is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
        """
        self._model = model
        self._input_transform = _transforms.compose(transforms)
        self._output_transform = _transforms.compose(output_transforms)
        self._sample_kwargs = sample_kwargs or {}
        self._metadata = metadata or {}
        self._is_pytorch_model = is_pytorch
        self._pytorch_device = pytorch_device
        if self._is_pytorch_model:
            # Move the model to the requested device and switch to eval mode.
            self._model = self._model.to(pytorch_device)
            self._model.eval()
            self._sample_actions = model.sample_actions
        else:
            # JAX model setup: JIT-compile the sampling function once up front.
            self._sample_actions = nnx_utils.module_jit(model.sample_actions)
            self._rng = rng or jax.random.key(0)
    @override
    def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: ignore[misc]
        # Make a copy since transformations may modify the inputs in place.
        inputs = jax.tree.map(lambda x: x, obs)
        inputs = self._input_transform(inputs)
        if not self._is_pytorch_model:
            # Make a batch and convert to jax.Array.
            inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)
            # Split the RNG so each call samples with a fresh key.
            self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng)
        else:
            # Convert inputs to PyTorch tensors and move to correct device
            inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs)
            sample_rng_or_pytorch_device = self._pytorch_device
        # Prepare kwargs for sample_actions
        sample_kwargs = dict(self._sample_kwargs)
        if noise is not None:
            # Match the noise array to the model's framework (and device for torch).
            noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)
            if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension
                noise = noise[None, ...] # Make it (1, action_horizon, action_dim)
            sample_kwargs["noise"] = noise
        observation = _model.Observation.from_dict(inputs)
        start_time = time.monotonic()
        outputs = {
            "state": inputs["state"],
            "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs),
        }
        model_time = time.monotonic() - start_time
        # Strip the batch dimension and bring results back to host numpy arrays.
        if self._is_pytorch_model:
            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...].detach().cpu()), outputs)
        else:
            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs)
        outputs = self._output_transform(outputs)
        outputs["policy_timing"] = {
            "infer_ms": model_time * 1000,
        }
        return outputs
    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata supplied at construction time (may be empty)."""
        return self._metadata
class PolicyRecorder(_base_policy.BasePolicy):
    """Records the policy's behavior to disk."""

    def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
        self._policy = policy

        logging.info(f"Dumping policy records to: {record_dir}")
        self._record_dir = pathlib.Path(record_dir)
        self._record_dir.mkdir(parents=True, exist_ok=True)
        self._record_step = 0

    @override
    def infer(self, obs: dict) -> dict:  # type: ignore[misc]
        results = self._policy.infer(obs)

        # Flatten the nested inputs/outputs into a single-level dict so the
        # whole record can be saved as one numpy object per step.
        record = flax.traverse_util.flatten_dict({"inputs": obs, "outputs": results}, sep="/")
        output_path = self._record_dir / f"step_{self._record_step}"
        self._record_step += 1
        np.save(output_path, np.asarray(record))
        return results

View File

@@ -0,0 +1,94 @@
import logging
import os
import pathlib
from typing import Any
import jax.numpy as jnp
import openpi.models.model as _model
import openpi.policies.policy as _policy
import openpi.shared.download as download
from openpi.training import checkpoints as _checkpoints
from openpi.training import config as _config
import openpi.transforms as transforms
def create_trained_policy(
    train_config: _config.TrainConfig,
    checkpoint_dir: pathlib.Path | str,
    *,
    repack_transforms: transforms.Group | None = None,
    sample_kwargs: dict[str, Any] | None = None,
    default_prompt: str | None = None,
    norm_stats: dict[str, transforms.NormStats] | None = None,
    pytorch_device: str | None = None,
) -> _policy.Policy:
    """Create a policy from a trained checkpoint.
    Args:
        train_config: The training config to use to create the model.
        checkpoint_dir: The directory to load the model from.
        repack_transforms: Optional transforms that will be applied before any other transforms.
        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
            kwargs will be used.
        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
            data if it doesn't already exist.
        norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
            from the checkpoint directory.
        pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
            If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".
    Note:
        The function automatically detects whether the model is PyTorch-based by checking for the
        presence of "model.safetensors" in the checkpoint directory.
    """
    repack_transforms = repack_transforms or transforms.Group()
    checkpoint_dir = download.maybe_download(str(checkpoint_dir))
    # Check if this is a PyTorch model by looking for model.safetensors.
    # Use pathlib consistently -- checkpoint_dir is already used as a Path below.
    weight_path = pathlib.Path(checkpoint_dir) / "model.safetensors"
    is_pytorch = weight_path.exists()
    logging.info("Loading model...")
    if is_pytorch:
        model = train_config.model.load_pytorch(train_config, str(weight_path))
        model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
    else:
        model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
    if norm_stats is None:
        # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
        # that the policy is using the same normalization stats as the original training process.
        if data_config.asset_id is None:
            raise ValueError("Asset id is required to load norm stats.")
        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
    # Determine the device to use for PyTorch models.
    if is_pytorch and pytorch_device is None:
        try:
            # Imported lazily so JAX-only installations do not require torch.
            import torch

            pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            pytorch_device = "cpu"
    return _policy.Policy(
        model,
        transforms=[
            *repack_transforms.inputs,
            transforms.InjectDefaultPrompt(default_prompt),
            *data_config.data_transforms.inputs,
            transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
        output_transforms=[
            *data_config.model_transforms.outputs,
            transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.data_transforms.outputs,
            *repack_transforms.outputs,
        ],
        sample_kwargs=sample_kwargs,
        metadata=train_config.policy_metadata,
        is_pytorch=is_pytorch,
        pytorch_device=pytorch_device if is_pytorch else None,
    )

View File

@@ -0,0 +1,34 @@
from openpi_client import action_chunk_broker
import pytest
from openpi.policies import aloha_policy
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config
@pytest.mark.manual
def test_infer():
    """Smoke test: a trained aloha_sim policy returns a full action chunk."""
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")

    result = policy.infer(aloha_policy.make_aloha_example())
    assert result["actions"].shape == (config.model.action_horizon, 14)
@pytest.mark.manual
def test_broker():
    """Smoke test: the chunk broker yields one 14-dim action per step."""
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
    broker = action_chunk_broker.ActionChunkBroker(
        policy,
        # Only execute the first half of the chunk.
        action_horizon=config.model.action_horizon // 2,
    )

    example = aloha_policy.make_aloha_example()
    for _ in range(config.model.action_horizon):
        assert broker.infer(example)["actions"].shape == (14,)

View File

@@ -0,0 +1,221 @@
import dataclasses
from typing import ClassVar
import einops
import numpy as np
from openpi import transforms
from pdb import set_trace
@dataclasses.dataclass(frozen=True)
class Reala2dInputs(transforms.DataTransformFn):
    """Inputs for the A2D policy."""

    adapt_to_pi: bool = True

    # The expected cameras names. All input cameras must be in this set. Missing cameras will be
    # replaced with black images and the corresponding `image_mask` will be set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_a2d(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # Assume that base image always exists.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist views are optional: pad missing ones with zeros and mask them out.
            for dest, source in (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            ):
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {
                "image": images,
                "image_mask": image_masks,
                "state": data["state"],
            }
        else:
            inputs = {
                "state": data["state"],
            }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs
@dataclasses.dataclass(frozen=True)
class Reala2dOutputs(transforms.DataTransformFn):
    """Outputs for the a2d policy."""

    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # The model may pad actions; A2D uses only the first 16 dims.
        actions = np.asarray(data["actions"][:, :16])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
def _joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def _normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def _unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def _gripper_to_angular(value):
    """Convert an Aloha linear gripper position to pi0's normalized angular space.

    Aloha stores gripper positions in a linear space, while pi0 was pretrained
    on an angular representation; this inverts that transformation.
    """
    # Aloha constants: PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # Inverse of the angular-to-linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)
    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # There are 4096 total encoder counts and aloha uses a zero of 2048, so in
    # radians the normalized inputs lie between (0.5476, 1.6296).
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
def _gripper_from_angular(value):
    """Convert a pi0 gripper command to the angular range used by Aloha.

    The units are already radians (trossen model predictions), so only the
    offset and range differ; see _gripper_to_angular for the 0.5476 constant.
    """
    shifted = value + 0.5476
    # Aloha constants: PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
def _gripper_from_angular_inv(value):
    """Exact inverse of _gripper_from_angular."""
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
def _decode_a2d(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten A2D state/action dicts into arrays and normalize images in place."""
    data["state"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def convert_image(img):
        img = np.asarray(img)
        # Convert to uint8 if using float images.
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(img, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: convert_image(img) for name, img in data["images"].items()}
    return data
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
joint = state["joint"]
gripper = state["gripper"]
state_left_arm_gripper = np.concatenate(
[
joint[:7],
gripper[:1],
],
axis=-1
)
state_right_arm_gripper = np.concatenate(
[
joint[7:14],
gripper[1:2],
],
axis=-1
)
state = np.concatenate(
[
state_left_arm_gripper,
state_right_arm_gripper,
],
axis=-1
)
return state
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
joint = action["joint"]
gripper = action["gripper"]
action_left_arm_gripper = np.concatenate(
[
joint[:,:7],
gripper[:,:1],
],
axis=-1
)
action_right_arm_gripper = np.concatenate(
[
joint[:,7:14],
gripper[:,1:2],
],
axis=-1
)
action = np.concatenate(
[
action_left_arm_gripper,
action_right_arm_gripper,
],
axis=-1
)
return action
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions

View File

@@ -0,0 +1,205 @@
import dataclasses
from typing import ClassVar
import einops
import numpy as np
from openpi import transforms
from pdb import set_trace
@dataclasses.dataclass(frozen=True)
class RealLift2Inputs(transforms.DataTransformFn):
    """Inputs for the Lift2 policy."""

    adapt_to_pi: bool = True

    # The expected cameras names. All input cameras must be in this set. Missing cameras will be
    # replaced with black images and the corresponding `image_mask` will be set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # Assume that base image always exists.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist views are optional: pad missing ones with zeros and mask them out.
            for dest, source in (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            ):
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {
                "image": images,
                "image_mask": image_masks,
                "state": data["state"],
            }
        else:
            inputs = {
                "state": data["state"],
            }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs
@dataclasses.dataclass(frozen=True)
class RealLift2Outputs(transforms.DataTransformFn):
    """Outputs for the Lift2 policy."""

    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # The model may pad actions; Lift2 uses only the first 14 dims.
        actions = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
def _joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def _normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def _unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def _gripper_to_angular(value):
    """Convert an Aloha linear gripper position to pi0's normalized angular space.

    Aloha stores gripper positions in a linear space, while pi0 was pretrained
    on an angular representation; this inverts that transformation.
    """
    # Aloha constants: PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # Inverse of the angular-to-linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)
    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # There are 4096 total encoder counts and aloha uses a zero of 2048, so in
    # radians the normalized inputs lie between (0.5476, 1.6296).
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
def _gripper_from_angular(value):
    """Convert a pi0 gripper command to the angular range used by Aloha.

    The units are already radians (trossen model predictions), so only the
    offset and range differ; see _gripper_to_angular for the 0.5476 constant.
    """
    shifted = value + 0.5476
    # Aloha constants: PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
def _gripper_from_angular_inv(value):
    """Exact inverse of _gripper_from_angular."""
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten Aloha state/action dicts into arrays and normalize images in place.

    state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]
    dim sizes: [6, 1, 6, 1]
    """
    data["state"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def convert_image(img):
        img = np.asarray(img)
        # Convert to uint8 if using float images.
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(img, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: convert_image(img) for name, img in data["images"].items()}
    return data
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Concatenate per-arm joints and grippers into one state vector.

    Handles both unbatched layouts (1-D joints with scalar grippers) and
    batched layouts (grippers one dim smaller than joints) by expanding the
    gripper arrays before concatenating.

    NOTE(review): if neither branch below matches (joints and grippers with
    the same ndim > 1, or other shapes), `state` is returned unchanged as the
    input dict -- presumably unreachable for well-formed data; confirm.
    """
    state_left_arm = state["left_joint"]
    state_left_gripper = state["left_gripper"]
    state_right_arm = state["right_joint"]
    state_right_gripper = state["right_gripper"]
    # Joints carry one more dimension than grippers: expand the grippers.
    if state_left_arm.ndim - state_left_gripper.ndim == 1:
        if state_left_gripper.ndim == 0:
            # Unbatched scalars -> 1-element vectors.
            state_left_gripper = state_left_gripper[None]
            state_right_gripper = state_right_gripper[None]
        state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=0)
    elif state_left_gripper.ndim == 1:
        # Same ndim, batched: grippers (batch,) -> (batch, 1) columns.
        state_left_gripper = state_left_gripper[:, None]
        state_right_gripper = state_right_gripper[:, None]
        state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=1)
    return state
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Concatenate per-arm joint and gripper actions into one action array.

    Mirrors `_decode_state`: grippers one dim smaller than the joints are
    expanded before concatenation.

    NOTE(review): if neither branch matches, the input dict is returned
    unchanged -- presumably unreachable for well-formed data; confirm.
    """
    action_left_arm = action["left_joint"]
    action_left_gripper = action["left_gripper"]
    action_right_arm = action["right_joint"]
    action_right_gripper = action["right_gripper"]
    # Joints carry one more dimension than grippers: expand the grippers.
    if action_left_arm.ndim - action_left_gripper.ndim == 1:
        if action_left_gripper.ndim == 0:
            # Unbatched scalars -> 1-element vectors.
            action_left_gripper = action_left_gripper[None]
            action_right_gripper = action_right_gripper[None]
        action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=0)
    elif action_left_gripper.ndim == 1:
        # Same ndim, batched: grippers (batch,) -> (batch, 1) columns.
        action_left_gripper = action_left_gripper[:, None]
        action_right_gripper = action_right_gripper[:, None]
        action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=1)
    return action
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions

View File

@@ -0,0 +1,207 @@
import dataclasses
from typing import ClassVar
import einops
import numpy as np
from openpi import transforms
from pdb import set_trace
@dataclasses.dataclass(frozen=True)
class Sim2RealSplitAlohaInputs(transforms.DataTransformFn):
    """Inputs for the Split Aloha policy."""

    adapt_to_pi: bool = True

    # The expected cameras names. All input cameras must be in this set. Missing cameras will be
    # replaced with black images and the corresponding `image_mask` will be set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # Assume that base image always exists.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist views are optional: pad missing ones with zeros and mask them out.
            for dest, source in (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            ):
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {
                "image": images,
                "image_mask": image_masks,
                "state": data["state"],
            }
        else:
            inputs = {
                "state": data["state"],
            }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs
@dataclasses.dataclass(frozen=True)
class Sim2RealSplitAlohaOutputs(transforms.DataTransformFn):
    """Outputs for the Split Aloha policy."""

    # If true, this will convert the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # The model may pad actions; keep only the first 14 dims.
        actions = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
def _joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def _normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def _unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def _gripper_to_angular(value):
    """Convert an Aloha linear gripper position to pi0's normalized angular space.

    Aloha stores gripper positions in a linear space, while pi0 was pretrained
    on an angular representation; this inverts that transformation.
    """
    # Aloha constants: PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # Inverse of the angular-to-linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)
    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # There are 4096 total encoder counts and aloha uses a zero of 2048, so in
    # radians the normalized inputs lie between (0.5476, 1.6296).
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
def _gripper_from_angular(value):
    """Convert a pi0 gripper command to the angular range used by Aloha.

    The units are already radians (trossen model predictions), so only the
    offset and range differ; see _gripper_to_angular for the 0.5476 constant.
    """
    shifted = value + 0.5476
    # Aloha constants: PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
def _gripper_from_angular_inv(value):
    """Exact inverse of _gripper_from_angular."""
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten Aloha state/action dicts into arrays and normalize images in place."""
    data["state"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def convert_image(img):
        img = np.asarray(img)
        # Convert to uint8 if using float images.
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(img, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: convert_image(img) for name, img in data["images"].items()}
    return data
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Concatenate per-arm joints and (zeroed) gripper slots into a state vector.

    NOTE(review): unlike the action decoder, the observed gripper values are
    replaced with constant zeros -- presumably intentional for this sim2real
    setup (hide gripper state from the model); confirm.
    """
    state_left_arm = state["left_joint"]
    state_left_gripper = state["left_gripper"]
    state_right_arm = state["right_joint"]
    state_right_gripper = state["right_gripper"]
    # Joints carry one more dimension than grippers: expand/zero the grippers.
    if state_left_arm.ndim - state_left_gripper.ndim == 1:
        if state_left_gripper.ndim == 0:
            # Unbatched: 1-element zero vectors take the gripper slots.
            state_left_gripper = np.array([0])
            state_right_gripper = np.array([0])
        state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=0)
    elif state_left_gripper.ndim == 1:
        # Batched: (1, 1) zero columns take the gripper slots.
        # NOTE(review): shape (1, 1) only broadcasts along axis 1 for batch
        # size 1 -- confirm batched inputs here always have a single row.
        state_left_gripper = np.array([[0]])
        state_right_gripper = np.array([[0]])
        state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=1)
    return state
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Concatenate per-arm joint and gripper-openness actions into one flat array."""
    action_left_arm = action["left_joint"]
    action_left_gripper = action["left_gripper_openness"]
    action_right_arm = action["right_joint"]
    action_right_gripper = action["right_gripper_openness"]
    # Arms are expected to carry exactly one more dimension than grippers.
    if action_left_arm.ndim - action_left_gripper.ndim == 1:
        if action_left_gripper.ndim == 0:
            # Unbatched case: promote scalars to 1-D so concatenation works.
            action_left_gripper = action_left_gripper[None]
            action_right_gripper = action_right_gripper[None]
            action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=0)
        elif action_left_gripper.ndim == 1:
            # Batched case: add a trailing feature axis and concatenate per row.
            action_left_gripper = action_left_gripper[:, None]
            action_right_gripper = action_right_gripper[:, None]
            action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=1)
    # NOTE(review): if the rank difference is not exactly 1, `action` is returned unchanged
    # as the original dict — confirm inputs always satisfy the check.
    return action
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions

View File

@@ -0,0 +1,185 @@
import dataclasses
from typing import ClassVar
import einops
import numpy as np
from openpi import transforms
from pdb import set_trace
@dataclasses.dataclass(frozen=True)
class SimFrankaInputs(transforms.DataTransformFn):
    """Inputs for the Franka policy.

    Repacks decoded Franka observations into the model's expected keys:
    "image"/"image_mask", "state", "pose", and optionally "actions"/"prompt".
    """

    # If true, decoded values are adapted to the pi runtime conventions.
    adapt_to_pi: bool = True

    # The expected cameras names. All input cameras must be in this set. Missing cameras will be
    # replaced with black images and the corresponding `image_mask` will be set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_franka(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # Assume that base image always exists.
            base_image = in_images["cam_high"]

            images = {
                "base_0_rgb": base_image,
            }
            image_masks = {
                "base_0_rgb": np.True_,
            }

            # Add the extra images. Absent cameras get a black frame and a False mask.
            extra_image_names = {
                "left_wrist_0_rgb": "cam_left_wrist",
                "right_wrist_0_rgb": "cam_right_wrist",
            }
            for dest, source in extra_image_names.items():
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {
                "image": images,
                "image_mask": image_masks,
                "state": data["state"],
                "pose": data["pose"],
            }
        else:
            inputs = {
                "state": data["state"],
                "pose": data["pose"],
            }

        # Actions are only available during training.
        if "actions" in data:
            actions = np.asarray(data["actions"])
            actions = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)
            inputs["actions"] = actions

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
@dataclasses.dataclass(frozen=True)
class SimFrankaOutputs(transforms.DataTransformFn):
    """Outputs for the Franka policy."""

    # If true, values are converted back from the pi runtime conventions.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only return the first 7 dims.
        # NOTE(review): if decoded actions are pose (7) + openness (1) = 8-D, this slice
        # drops the gripper-openness channel — confirm that is intended.
        actions = np.asarray(data["actions"][:, :7])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
def _joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def _normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def _unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def _gripper_to_angular(value):
    """Map an Aloha linear gripper position into pi0's normalized angular space."""
    # Aloha transforms the gripper positions into a linear space. The following code
    # reverses this transformation to be consistent with pi0 which is pretrained in
    # angular space.
    #
    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    value = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        # Law-of-cosines style inversion; clip keeps arcsin in its valid domain.
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(value, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # There are 4096 total encoder counts and aloha uses a zero of 2048.
    # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296)
    return _normalize(value, min_val=0.5476, max_val=1.6296)
def _gripper_from_angular(value):
    """Convert from the gripper position used by pi0 to the gripper position that is used by Aloha."""
    # Note that the units are still angular but the range is different.
    # We do not scale the output since the trossen model predictions are already in radians.
    # See the comment in _gripper_to_angular for a derivation of the constant.
    value = value + 0.5476

    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return _normalize(value, min_val=-0.6213, max_val=1.4910)
def _gripper_from_angular_inv(value):
    """Exact inverse of `_gripper_from_angular`."""
    # Undo the (-0.6213, 1.4910) normalization, then remove the constant offset.
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
def _decode_franka(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten a raw Franka sample in place.

    Replaces "state_dict" with flat "state"/"pose" arrays, "action_dict" with
    "actions", and converts any images to channel-last layout (uint8).
    """
    state_dict = data["state_dict"]
    data["state"], data["pose"] = _decode_state(state_dict, adapt_to_pi=adapt_to_pi)
    del data["state_dict"]

    action_dict = data["action_dict"]
    data["actions"] = _decode_action(action_dict, adapt_to_pi=adapt_to_pi)
    del data["action_dict"]

    def convert_image(img):
        img = np.asarray(img)
        # Convert to uint8 if using float images.
        # NOTE(review): assumes float images are in [0, 1] — confirm the upstream producer.
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(img, "c h w -> h w c")

    if "images" in data:
        images = data["images"]
        images_dict = {name: convert_image(img) for name, img in images.items()}
        data["images"] = images_dict
    return data
def _decode_state(state, *, adapt_to_pi: bool = False) -> tuple[np.ndarray, np.ndarray]:
    """Build flat (state, pose) vectors from the raw Franka state dict.

    Returns:
        A tuple of (joint state with gripper appended, end-effector pose with gripper appended).
    """
    # gripper_position is a scalar; add a leading axis so it can be concatenated.
    gripper_position = state["gripper_position"][None]
    gripper_pose = state["gripper_pose"]
    joint_position = state["joint_position"]
    state = np.concatenate([joint_position, gripper_position], axis=0)
    pose = np.concatenate([gripper_pose, gripper_position], axis=0)
    return state, pose
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Concatenate gripper pose and openness into a single action array."""
    gripper_pose = action["gripper_pose"]
    # Append a trailing axis so openness becomes an extra column of the action matrix.
    gripper_openness = action["gripper_openness"][..., None]
    action = np.concatenate([gripper_pose, gripper_openness], axis=1)
    return action
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions

View File

@@ -0,0 +1,208 @@
import dataclasses
from typing import ClassVar
import einops
import numpy as np
from openpi import transforms
from pdb import set_trace
@dataclasses.dataclass(frozen=True)
class SimSplitAlohaInputs(transforms.DataTransformFn):
    """Inputs for the Split Aloha policy.

    Repacks decoded aloha observations into the model's expected keys:
    "image"/"image_mask", "state", and optionally "actions"/"prompt".
    """

    # If true, decoded values are adapted to the pi runtime conventions.
    adapt_to_pi: bool = True

    # The expected cameras names. All input cameras must be in this set. Missing cameras will be
    # replaced with black images and the corresponding `image_mask` will be set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # Assume that base image always exists.
            base_image = in_images["cam_high"]

            images = {
                "base_0_rgb": base_image,
            }
            image_masks = {
                "base_0_rgb": np.True_,
            }

            # Add the extra images. Absent cameras get a black frame and a False mask.
            extra_image_names = {
                "left_wrist_0_rgb": "cam_left_wrist",
                "right_wrist_0_rgb": "cam_right_wrist",
            }
            for dest, source in extra_image_names.items():
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {
                "image": images,
                "image_mask": image_masks,
                "state": data["state"],
            }
        else:
            inputs = {
                "state": data["state"],
            }

        # Actions are only available during training.
        if "actions" in data:
            actions = np.asarray(data["actions"])
            actions = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)
            inputs["actions"] = actions

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
@dataclasses.dataclass(frozen=True)
class SimSplitAlohaOutputs(transforms.DataTransformFn):
    """Outputs for the Split Aloha policy."""

    # If true, this will convert the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only return the first 14 dims (7 joints + 1 gripper per arm).
        actions = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
def _joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def _normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def _unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def _gripper_to_angular(value):
    """Map an Aloha linear gripper position into pi0's normalized angular space."""
    # Aloha transforms the gripper positions into a linear space. The following code
    # reverses this transformation to be consistent with pi0 which is pretrained in
    # angular space.
    #
    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    value = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        # Law-of-cosines style inversion; clip keeps arcsin in its valid domain.
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(value, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # There are 4096 total encoder counts and aloha uses a zero of 2048.
    # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296)
    return _normalize(value, min_val=0.5476, max_val=1.6296)
def _gripper_from_angular(value):
    """Convert from the gripper position used by pi0 to the gripper position that is used by Aloha."""
    # Note that the units are still angular but the range is different.
    # We do not scale the output since the trossen model predictions are already in radians.
    # See the comment in _gripper_to_angular for a derivation of the constant.
    value = value + 0.5476

    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return _normalize(value, min_val=-0.6213, max_val=1.4910)
def _gripper_from_angular_inv(value):
    """Exact inverse of `_gripper_from_angular`."""
    # Undo the (-0.6213, 1.4910) normalization, then remove the constant offset.
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten a raw split-aloha sample in place.

    Replaces "state_dict"/"action_dict" with flat "state"/"actions" arrays and
    converts any images to channel-last layout (uint8).
    """
    # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]
    # dim sizes: [7, 1, 7, 1]
    state_dict = data["state_dict"]
    data["state"] = _decode_state(state_dict, adapt_to_pi=adapt_to_pi)
    del data["state_dict"]

    action_dict = data["action_dict"]
    data["actions"] = _decode_action(action_dict, adapt_to_pi=adapt_to_pi)
    del data["action_dict"]

    def convert_image(img):
        img = np.asarray(img)
        # Convert to uint8 if using float images.
        # NOTE(review): assumes float images are in [0, 1] — confirm the upstream producer.
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(img, "c h w -> h w c")

    if "images" in data:
        images = data["images"]
        images_dict = {name: convert_image(img) for name, img in images.items()}
        data["images"] = images_dict
    return data
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Concatenate per-arm joints and grippers into one flat state vector."""
    state_left_arm = state["left_joint"]
    state_left_gripper = state["left_gripper"]
    state_right_arm = state["right_joint"]
    state_right_gripper = state["right_gripper"]
    # Arms are expected to carry exactly one more dimension than grippers.
    if state_left_arm.ndim - state_left_gripper.ndim == 1:
        if state_left_gripper.ndim == 0:
            # Unbatched case: promote scalars to 1-D so concatenation works.
            state_left_gripper = state_left_gripper[None]
            state_right_gripper = state_right_gripper[None]
            state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=0)
        elif state_left_gripper.ndim == 1:
            # Batched case: add a trailing feature axis and concatenate per row.
            state_left_gripper = state_left_gripper[:, None]
            state_right_gripper = state_right_gripper[:, None]
            state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=1)
    # NOTE(review): if the rank difference is not exactly 1, `state` is never assigned and
    # this raises UnboundLocalError — confirm inputs always satisfy the check.
    return state
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Concatenate per-arm joint and gripper-openness actions into one flat array."""
    action_left_arm = action["left_joint"]
    action_left_gripper = action["left_gripper_openness"]
    action_right_arm = action["right_joint"]
    action_right_gripper = action["right_gripper_openness"]
    # Arms are expected to carry exactly one more dimension than grippers.
    if action_left_arm.ndim - action_left_gripper.ndim == 1:
        if action_left_gripper.ndim == 0:
            # Unbatched case: promote scalars to 1-D so concatenation works.
            action_left_gripper = action_left_gripper[None]
            action_right_gripper = action_right_gripper[None]
            action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=0)
        elif action_left_gripper.ndim == 1:
            # Batched case: add a trailing feature axis and concatenate per row.
            action_left_gripper = action_left_gripper[:, None]
            action_right_gripper = action_right_gripper[:, None]
            action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=1)
    # NOTE(review): if the rank difference is not exactly 1, the original dict is returned
    # unchanged — confirm inputs always satisfy the check.
    return action
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
return actions

View File

@@ -0,0 +1,90 @@
import asyncio
import http
import logging
import time
import traceback
from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
import websockets.asyncio.server as _server
import websockets.frames
logger = logging.getLogger(__name__)
class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    Currently only implements the `load` and `infer` methods.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int | None = None,
        metadata: dict | None = None,
    ) -> None:
        # metadata is sent to each client immediately after the handshake.
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        """Blocking entry point: runs the async server until interrupted."""
        asyncio.run(self.run())

    async def run(self):
        """Start the websocket server and serve connections forever."""
        # compression/max_size disabled: payloads are pre-packed msgpack arrays.
        async with _server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
            process_request=_health_check,
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: _server.ServerConnection):
        """Per-connection loop: send metadata once, then answer infer requests until close."""
        logger.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        await websocket.send(packer.pack(self._metadata))

        prev_total_time = None
        while True:
            try:
                start_time = time.monotonic()
                obs = msgpack_numpy.unpackb(await websocket.recv())

                infer_time = time.monotonic()
                action = self._policy.infer(obs)
                infer_time = time.monotonic() - infer_time

                action["server_timing"] = {
                    "infer_ms": infer_time * 1000,
                }
                if prev_total_time is not None:
                    # We can only record the last total time since we also want to include the send time.
                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000

                await websocket.send(packer.pack(action))
                prev_total_time = time.monotonic() - start_time

            except websockets.ConnectionClosed:
                logger.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                # Send the traceback to the client before closing so failures are debuggable remotely.
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise
def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
    """Answer /healthz probes directly; defer everything else to the websocket handshake."""
    if request.path != "/healthz":
        # Continue with the normal request handling.
        return None
    return connection.respond(http.HTTPStatus.OK, "OK\n")

View File

@@ -0,0 +1,89 @@
import contextlib
import functools as ft
import inspect
from typing import TypeAlias, TypeVar, cast
import beartype
import jax
import jax._src.tree_util as private_tree_util
import jax.core
from jaxtyping import ArrayLike
from jaxtyping import Bool # noqa: F401
from jaxtyping import DTypeLike # noqa: F401
from jaxtyping import Float
from jaxtyping import Int # noqa: F401
from jaxtyping import Key # noqa: F401
from jaxtyping import Num # noqa: F401
from jaxtyping import PyTree
from jaxtyping import Real # noqa: F401
from jaxtyping import UInt8 # noqa: F401
from jaxtyping import config
from jaxtyping import jaxtyped
import jaxtyping._decorator
import torch
# patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277.
# the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`,
# `jax.Sharding`, or even <object>) due to JAX tracing operations. this patch skips typechecking when the stack trace
# contains `jax._src.tree_util`, which should only be the case during tree unflattening.
_original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_annotations # noqa: SLF001
# Redefine Array to include both JAX arrays and PyTorch tensors
Array = jax.Array | torch.Tensor
def _check_dataclass_annotations(self, typechecker):
    """Patched jaxtyping hook: skip dataclass typechecking during JAX tree unflattening.

    Tree unflattening (and nnx jit compilation) can temporarily place tracer or
    placeholder objects in typed fields, which would otherwise fail the runtime
    check — see the jaxtyping issue referenced above.
    """
    if not any(
        frame.frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"}
        for frame in inspect.stack()
    ):
        return _original_check_dataclass_annotations(self, typechecker)
    return None


# Install the patched hook in place of jaxtyping's original implementation.
jaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations  # noqa: SLF001
KeyArrayLike: TypeAlias = jax.typing.ArrayLike
Params: TypeAlias = PyTree[Float[ArrayLike, "..."]]
T = TypeVar("T")
# runtime type-checking decorator
def typecheck(t: T) -> T:
    """Decorate `t` so its jaxtyping shape/dtype annotations are enforced at runtime via beartype."""
    return cast(T, ft.partial(jaxtyped, typechecker=beartype.beartype)(t))
@contextlib.contextmanager
def disable_typechecking():
    """Temporarily disable jaxtyping runtime checks within the `with` block.

    The previous setting is restored even if the body raises; the prior version
    skipped restoration on error, leaving typechecking disabled globally.
    """
    initial = config.jaxtyping_disable
    config.update("jaxtyping_disable", True)  # noqa: FBT003
    try:
        yield
    finally:
        config.update("jaxtyping_disable", initial)
def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False):
    """Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer
    error message than if `jax.tree.map` is naively used on PyTrees with different structures.

    Raises:
        ValueError: if the structures differ, or (when requested) a leaf shape/dtype differs.
    """
    # Uses JAX's private equality_errors helper to get a per-keypath explanation.
    if errors := list(private_tree_util.equality_errors(expected, got)):
        raise ValueError(
            "PyTrees have different structure:\n"
            + (
                "\n".join(
                    f" - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\n"
                    for path, thing1, thing2, explanation in errors
                )
            )
        )
    if check_shapes or check_dtypes:

        def check(kp, x, y):
            if check_shapes and x.shape != y.shape:
                raise ValueError(f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}")
            if check_dtypes and x.dtype != y.dtype:
                raise ValueError(f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}")

        jax.tree_util.tree_map_with_path(check, expected, got)

View File

@@ -0,0 +1,194 @@
import concurrent.futures
import datetime
import logging
import os
import pathlib
import re
import shutil
import stat
import time
import urllib.parse
import filelock
import fsspec
import fsspec.generic
import tqdm_loggable.auto as tqdm
# Environment variable to control cache directory path, ~/.cache/openpi will be used by default.
_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"
DEFAULT_CACHE_DIR = "~/.cache/openpi"
logger = logging.getLogger(__name__)
def get_cache_dir() -> pathlib.Path:
    """Return the local cache directory, creating it (with shared permissions) if needed.

    Controlled by the OPENPI_DATA_HOME environment variable; defaults to ~/.cache/openpi.
    """
    cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Opened up so containerized runtimes and training scripts can share the cache.
    _set_folder_permission(cache_dir)
    return cache_dir
def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
    """Download a file or directory from a remote filesystem to the local cache, and return the local path.

    If the local file already exists, it will be returned directly.
    It is safe to call this function concurrently from multiple processes.
    See `get_cache_dir` for more details on the cache directory.

    Args:
        url: URL to the file to download.
        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.

    Returns:
        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.
    """
    # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem.
    parsed = urllib.parse.urlparse(url)

    # Short circuit if this is a local path.
    if parsed.scheme == "":
        path = pathlib.Path(url)
        if not path.exists():
            raise FileNotFoundError(f"File not found at {url}")
        return path.resolve()

    cache_dir = get_cache_dir()
    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
    local_path = local_path.resolve()

    # Check if the cache should be invalidated.
    invalidate_cache = False
    if local_path.exists():
        if force_download or _should_invalidate_cache(cache_dir, local_path):
            invalidate_cache = True
        else:
            return local_path

    try:
        # NOTE(review): with_suffix() *replaces* an existing extension, so e.g. "model.bin"
        # and "model.ckpt" in the same directory would share "model.lock" — confirm cached
        # entries can never collide this way.
        lock_path = local_path.with_suffix(".lock")
        with filelock.FileLock(lock_path):
            # Ensure consistent permissions for the lock file.
            _ensure_permissions(lock_path)
            # First, remove the existing cache if it is expired.
            if invalidate_cache:
                logger.info(f"Removing expired cached entry: {local_path}")
                if local_path.is_dir():
                    shutil.rmtree(local_path)
                else:
                    local_path.unlink()
            # Download the data to a local cache. Using a scratch path keeps partially
            # downloaded data from ever appearing at the final location.
            logger.info(f"Downloading {url} to {local_path}")
            scratch_path = local_path.with_suffix(".partial")
            _download_fsspec(url, scratch_path, **kwargs)
            shutil.move(scratch_path, local_path)
            _ensure_permissions(local_path)
    except PermissionError as e:
        msg = (
            f"Local file permission error was encountered while downloading {url}. "
            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
        )
        raise PermissionError(msg) from e

    return local_path
def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
    """Download a file or directory from a remote filesystem to `local_path` with a progress bar.

    Raises:
        Exception: whatever the underlying `fs.get` raised; the previous version never
        called `future.result()`, so download errors were silently swallowed and only
        surfaced later as a confusing missing-file failure.
    """
    fs, _ = fsspec.core.url_to_fs(url, **kwargs)
    info = fs.info(url)
    # Folders are represented by 0-byte objects with a trailing forward slash.
    if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))):
        total_size = fs.du(url)
    else:
        total_size = info["size"]
    # Run the transfer on a worker thread so we can poll on-disk size for progress.
    with (
        tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar,
        concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor,
    ):
        future = executor.submit(fs.get, url, local_path, recursive=is_dir)
        while not future.done():
            current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file())
            pbar.update(current_size - pbar.n)
            time.sleep(1)
        # Re-raise any exception from the download thread instead of ignoring it.
        future.result()
        pbar.update(total_size - pbar.n)
def _set_permission(path: pathlib.Path, target_permission: int):
    """chmod `path` to `target_permission` unless the bits are already present.

    chmod requires executable permission to be set, so we skip if the permission
    already matches the target.
    """
    current_mode = path.stat().st_mode
    if current_mode & target_permission == target_permission:
        logger.debug(f"Skipping {path} because it already has correct permissions")
        return
    path.chmod(target_permission)
    logger.debug(f"Set {path} to {target_permission}")
def _set_folder_permission(folder_path: pathlib.Path) -> None:
    """Make the folder readable, writable and searchable by user, group and others."""
    folder_mode = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO
    _set_permission(folder_path, folder_mode)
def _ensure_permissions(path: pathlib.Path) -> None:
    """Since we are sharing cache directory with containerized runtime as well as training script, we need to
    ensure that the cache directory has the correct permissions.
    """

    def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None:
        # Open up every intermediate directory between the cache root and `path`.
        cache_dir = get_cache_dir()
        relative_path = path.relative_to(cache_dir)
        moving_path = cache_dir
        for part in relative_path.parts:
            _set_folder_permission(moving_path / part)
            moving_path = moving_path / part

    def _set_file_permission(file_path: pathlib.Path) -> None:
        """Set all files to be read & writable, if it is a script, keep it as a script."""
        file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH
        # 0o100 = owner-execute bit: preserve executability for scripts.
        if file_path.stat().st_mode & 0o100:
            _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        else:
            _set_permission(file_path, file_rw)

    _setup_folder_permission_between_cache_dir_and_path(path)

    # Recursively fix permissions under `path` (no-op for a regular file).
    for root, dirs, files in os.walk(str(path)):
        root_path = pathlib.Path(root)
        for file in files:
            file_path = root_path / file
            _set_file_permission(file_path)
        for dir in dirs:
            dir_path = root_path / dir
            _set_folder_permission(dir_path)
def _get_mtime(year: int, month: int, day: int) -> float:
"""Get the mtime of a given date at midnight UTC."""
date = datetime.datetime(year, month, day, tzinfo=datetime.UTC)
return time.mktime(date.timetuple())
# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format).
# Partial matching will be used from top to bottom and the first match will be chosen.
# Cached entries will be retained only if they are newer than the expiration timestamp.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
    # Catch-all for every other checkpoint; listed last so the specific entries above win.
    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Invalidate the cache if it is expired. Return True if the cache was invalidated.

    Args:
        cache_dir: Root of the local cache directory.
        local_path: Existing cached entry to check; must live under `cache_dir`.
    """
    if not local_path.exists():
        # Explicit error instead of `assert`, which would be stripped under `python -O`.
        raise FileNotFoundError(f"File not found at {local_path}")
    relative_path = str(local_path.relative_to(cache_dir))
    for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
        if pattern.match(relative_path):
            # Remove if not newer than the expiration timestamp.
            return local_path.stat().st_mtime <= expire_time
    return False

View File

@@ -0,0 +1,54 @@
import pathlib
import pytest
import openpi.shared.download as download
@pytest.fixture(scope="session", autouse=True)
def set_openpi_data_home(tmp_path_factory):
    """Point OPENPI_DATA_HOME at a throwaway directory for the whole test session."""
    temp_dir = tmp_path_factory.mktemp("openpi_data")
    with pytest.MonkeyPatch().context() as mp:
        mp.setenv("OPENPI_DATA_HOME", str(temp_dir))
        yield
def test_download_local(tmp_path: pathlib.Path):
    """Local paths are returned as-is; missing local paths raise FileNotFoundError."""
    local_path = tmp_path / "local"
    local_path.touch()

    result = download.maybe_download(str(local_path))
    assert result == local_path

    with pytest.raises(FileNotFoundError):
        download.maybe_download("bogus")
def test_download_gs_dir():
    """A GCS directory is downloaded once and served from cache afterwards. Requires network."""
    remote_path = "gs://openpi-assets/testdata/random"

    local_path = download.maybe_download(remote_path)
    assert local_path.exists()

    # Second call should hit the cache and return the same path.
    new_local_path = download.maybe_download(remote_path)
    assert new_local_path == local_path
def test_download_gs():
    """A single GCS file is downloaded once and served from cache afterwards. Requires network."""
    remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin"

    local_path = download.maybe_download(remote_path)
    assert local_path.exists()

    # Second call should hit the cache and return the same path.
    new_local_path = download.maybe_download(remote_path)
    assert new_local_path == local_path
def test_download_fsspec():
    """Extra fsspec kwargs (anonymous token) are forwarded to the filesystem. Requires network."""
    remote_path = "gs://big_vision/paligemma_tokenizer.model"

    local_path = download.maybe_download(remote_path, gs={"token": "anon"})
    assert local_path.exists()

    # Second call should hit the cache and return the same path.
    new_local_path = download.maybe_download(remote_path, gs={"token": "anon"})
    assert new_local_path == local_path

View File

@@ -0,0 +1,126 @@
import functools
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F # noqa: N812
import openpi.shared.array_typing as at
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
@at.typecheck
def resize_with_pad(
    images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"],
    height: int,
    width: int,
    method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR,
) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]:
    """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].
    """
    # Temporarily add a batch axis so the resize below can assume rank 4.
    has_batch_dim = images.ndim == 4
    if not has_batch_dim:
        images = images[None]  # type: ignore

    cur_height, cur_width = images.shape[1:3]
    # Scale so the image fits inside (height, width) while keeping its aspect ratio.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_images = jax.image.resize(
        images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method
    )

    if images.dtype == jnp.uint8:
        # round from float back to uint8
        resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8)
    elif images.dtype == jnp.float32:
        resized_images = resized_images.clip(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Split any odd padding remainder onto the bottom/right edge.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w
    # "Black" is 0 for uint8 images and -1 for [-1, 1] float images.
    padded_images = jnp.pad(
        resized_images,
        ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)),
        constant_values=0 if images.dtype == jnp.uint8 else -1.0,
    )

    if not has_batch_dim:
        padded_images = padded_images[0]
    return padded_images
def resize_with_pad_torch(
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with same shape format (and batch rank) as input
    """
    # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w].
    # NOTE(review): this heuristic assumes images never have a trailing spatial dim <= 4 — confirm.
    added_batch_dim = False
    if images.shape[-1] <= 4:  # Assume channels-last format
        channels_last = True
        if images.dim() == 3:
            images = images.unsqueeze(0)  # Add batch dimension
            added_batch_dim = True
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]
    else:
        channels_last = False
        if images.dim() == 3:
            images = images.unsqueeze(0)  # Add batch dimension
            added_batch_dim = True

    batch_size, channels, cur_height, cur_width = images.shape

    # Scale so the image fits inside (height, width) while keeping its aspect ratio.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_images = F.interpolate(
        images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None
    )

    # Handle dtype-specific clipping.
    if images.dtype == torch.uint8:
        # NOTE(review): F.interpolate with 'bilinear' requires float input; uint8 tensors may
        # need casting before interpolation — confirm intended usage for this branch.
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Split any odd padding remainder onto the bottom/right edge.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # "Black" is 0 for uint8 images and -1 for [-1, 1] float images.
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to the input's original layout.
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    # Only drop the batch dimension if *we* added it. The previous `batch_size == 1`
    # check also squeezed genuine batch-of-one inputs, changing their rank.
    if added_batch_dim:
        padded_images = padded_images.squeeze(0)

    return padded_images

View File

@@ -0,0 +1,37 @@
import jax.numpy as jnp
from openpi.shared import image_tools
def test_resize_with_pad_shapes():
    """Verify `resize_with_pad` returns the requested shape and preserves zero pixels.

    Covers upscaling, downscaling, a no-op resize, and a target whose aspect
    ratio differs from the source (forcing odd/asymmetric padding).
    """
    # Test case 1: Resize image with larger dimensions (upscale).
    images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8)  # Input images of shape (batch_size, height, width, channels)
    height = 20
    width = 20
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (2, height, width, 3)
    # An all-zero input must stay all-zero after resize + padding.
    assert jnp.all(resized_images == 0)
    # Test case 2: Resize image with smaller dimensions (downscale).
    images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8)
    height = 15
    width = 15
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (3, height, width, 3)
    assert jnp.all(resized_images == 0)
    # Test case 3: Resize image with the same dimensions (no-op).
    images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8)
    height = 50
    width = 50
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert jnp.all(resized_images == 0)
    # Test case 4: Resize image with odd-numbered padding (was mislabeled "Test case 3").
    images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8)
    height = 60
    width = 80
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert jnp.all(resized_images == 0)

View File

@@ -0,0 +1,69 @@
from collections.abc import Callable
import dataclasses
import functools
import inspect
import re
from typing import Any, ParamSpec, TypeVar
import flax.nnx as nnx
import jax
P = ParamSpec("P")
R = TypeVar("R")
def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]:
    """JIT-compile a bound `nnx.Module` method against a snapshot of the module's state.

    Unlike `nnx.jit` — which traverses the NNX module graph on every call and tracks
    module mutations, costing both time and memory — this helper splits the module
    once and closes over the resulting graphdef/state pair. The returned callable
    behaves exactly like the original bound method except that the module state is
    frozen at wrap time: mutations made inside `meth` happen on a merged copy and are
    discarded when the call returns. See
    https://github.com/google/flax/discussions/4224 for background.
    """
    if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)):
        raise ValueError("module_jit must only be used on bound methods of nnx.Modules.")

    graphdef, frozen_state = nnx.split(meth.__self__)

    def call_on_merged(module_state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R:
        # Rebuild a module instance from the captured graphdef plus the passed-in
        # state, then invoke the unbound function on it.
        rebuilt = nnx.merge(graphdef, module_state)
        return meth.__func__(rebuilt, *args, **kwargs)

    compiled = jax.jit(call_on_merged, *jit_args, **jit_kwargs)

    @functools.wraps(meth)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return compiled(frozen_state, *args, **kwargs)

    return wrapper
@dataclasses.dataclass(frozen=True)
class PathRegex:
    """NNX filter selecting state entries whose joined path matches a regex in full.

    Path parts are joined into one string using `sep` (default "/") before the
    full-match test is applied.
    """

    pattern: str | re.Pattern
    sep: str = "/"

    def __post_init__(self):
        # Normalize to a compiled pattern eagerly; the dataclass is frozen, so the
        # field has to be replaced via object.__setattr__.
        if not isinstance(self.pattern, re.Pattern):
            object.__setattr__(self, "pattern", re.compile(self.pattern))

    def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool:
        compiled = self.pattern
        assert isinstance(compiled, re.Pattern)
        flattened = self.sep.join(str(part) for part in path)
        return compiled.fullmatch(flattened) is not None
def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State:
    """Return a copy of `state` with `fn` applied to every leaf selected by `filter`."""
    selected = set(state.filter(filter).flat_state())
    return state.map(lambda key, leaf: fn(leaf) if key in selected else leaf)

View File

@@ -0,0 +1,199 @@
import json
import pathlib
import numpy as np
import numpydantic
import pydantic
@pydantic.dataclasses.dataclass
class NormStats:
    """Per-dimension normalization statistics for a vector-valued feature."""

    mean: numpydantic.NDArray  # per-dimension mean
    std: numpydantic.NDArray  # per-dimension standard deviation
    q01: numpydantic.NDArray | None = None  # 1st quantile
    q99: numpydantic.NDArray | None = None  # 99th quantile
class RunningStats:
    """Compute running statistics of a batch of vectors.

    Mean and standard deviation are tracked exactly in a streaming fashion;
    quantiles are approximated with one fixed-size histogram per vector
    dimension, rebinned whenever the observed min/max range grows.
    """

    def __init__(self):
        self._count = 0  # total number of vectors seen so far
        self._mean = None  # running per-dimension mean
        self._mean_of_squares = None  # running per-dimension E[x^2]
        self._min = None  # per-dimension minimum seen
        self._max = None  # per-dimension maximum seen
        self._histograms = None  # one count histogram per vector dimension
        self._bin_edges = None  # matching bin edges per dimension
        self._num_quantile_bins = 5000  # for computing quantiles on the fly

    def update(self, batch: np.ndarray) -> None:
        """
        Update the running statistics with a batch of vectors.

        Args:
            batch (np.ndarray): An array where all dimensions except the last are
                batch dimensions; the last axis is the vector dimension.

        Raises:
            ValueError: if the vector length differs from previous updates.
        """
        batch = batch.reshape(-1, batch.shape[-1])
        num_elements, vector_length = batch.shape
        if self._count == 0:
            # First batch: initialize all statistics from the observed data.
            self._mean = np.mean(batch, axis=0)
            self._mean_of_squares = np.mean(batch**2, axis=0)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
            # Widen the range by a tiny epsilon so boundary values land inside the bins.
            self._bin_edges = [
                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
                for i in range(vector_length)
            ]
        else:
            if vector_length != self._mean.size:
                raise ValueError("The length of new vectors does not match the initialized vector length.")
            new_max = np.max(batch, axis=0)
            new_min = np.min(batch, axis=0)
            max_changed = np.any(new_max > self._max)
            min_changed = np.any(new_min < self._min)
            self._max = np.maximum(self._max, new_max)
            self._min = np.minimum(self._min, new_min)
            # If the observed range grew, rebin existing histograms to cover it
            # before adding the new batch's counts.
            if max_changed or min_changed:
                self._adjust_histograms()
        self._count += num_elements
        batch_mean = np.mean(batch, axis=0)
        batch_mean_of_squares = np.mean(batch**2, axis=0)
        # Update running mean and mean of squares.
        self._mean += (batch_mean - self._mean) * (num_elements / self._count)
        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count)
        self._update_histograms(batch)

    def get_statistics(self) -> NormStats:
        """
        Compute and return the statistics of the vectors processed so far.

        Returns:
            NormStats with exact mean/std and histogram-approximated q01/q99.

        Raises:
            ValueError: if fewer than 2 vectors have been processed.
        """
        if self._count < 2:
            raise ValueError("Cannot compute statistics for less than 2 vectors.")
        # Clamp the variance at 0 to guard against negative values from
        # floating-point round-off.
        variance = self._mean_of_squares - self._mean**2
        stddev = np.sqrt(np.maximum(0, variance))
        q01, q99 = self._compute_quantiles([0.01, 0.99])
        return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99)

    def _adjust_histograms(self):
        """Adjust histograms when min or max changes."""
        for i in range(len(self._histograms)):
            old_edges = self._bin_edges[i]
            new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1)
            # Redistribute the existing histogram counts to the new bins.
            # NOTE: counts are re-assigned by each old bin's LEFT edge, so this is
            # an approximation — mass near bin boundaries may shift slightly.
            new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i])
            self._histograms[i] = new_hist
            self._bin_edges[i] = new_edges

    def _update_histograms(self, batch: np.ndarray) -> None:
        """Update histograms with new vectors."""
        for i in range(batch.shape[1]):
            hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
            self._histograms[i] += hist

    def _compute_quantiles(self, quantiles):
        """Compute approximate quantiles from the per-dimension histograms.

        Returns one array (of length vector_length) per requested quantile,
        using the left edge of the first bin whose cumulative count reaches
        the target.
        """
        results = []
        for q in quantiles:
            target_count = q * self._count
            q_values = []
            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
                cumsum = np.cumsum(hist)
                idx = np.searchsorted(cumsum, target_count)
                q_values.append(edges[idx])
            results.append(np.array(q_values))
        return results
class OptimizedRunningStats:
    """Faster alternative to `RunningStats`.

    Mean/std are computed exactly from running float64 sums; quantiles are
    estimated from a random subsample of rows, so results are approximate and
    nondeterministic across runs.
    """

    def __init__(self, num_quantile_bins=1000):  # fewer bins than RunningStats
        self._count = 0  # total number of vectors seen
        self._sum = None  # running per-dimension sum (float64 for accuracy)
        self._sum_sq = None  # running per-dimension sum of squares
        self._min = None
        self._max = None
        self._all_samples = []  # sampled rows retained for quantile estimation
        self._sample_rate = 0.01  # probability of sampling rows from a given batch (1%)
        # NOTE(review): stored but never read anywhere in this class.
        self._num_quantile_bins = num_quantile_bins

    def update(self, batch: np.ndarray) -> None:
        """Fold `batch` into the running sums/min/max; last axis is the vector dim."""
        batch = batch.reshape(-1, batch.shape[-1])
        num_elements = batch.shape[0]
        # Update the basic statistics (vectorized).
        if self._count == 0:
            self._sum = np.sum(batch, axis=0, dtype=np.float64)
            self._sum_sq = np.sum(batch**2, axis=0, dtype=np.float64)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
        else:
            self._sum += np.sum(batch, axis=0, dtype=np.float64)
            self._sum_sq += np.sum(batch**2, axis=0, dtype=np.float64)
            self._min = np.minimum(self._min, np.min(batch, axis=0))
            self._max = np.maximum(self._max, np.max(batch, axis=0))
        # Randomly subsample rows for quantile estimation instead of storing all data.
        # NOTE(review): whole batches are skipped with 99% probability and rows are
        # drawn WITH replacement, so quantiles are a rough estimate only.
        if np.random.random() < self._sample_rate:
            sample_idx = np.random.randint(0, num_elements, size=min(100, num_elements))
            self._all_samples.append(batch[sample_idx])
        self._count += num_elements

    def get_statistics(self):
        """Return NormStats; q01/q99 fall back to zeros if no rows were sampled.

        Raises:
            ValueError: if fewer than 2 vectors have been processed.
        """
        if self._count < 2:
            raise ValueError("Cannot compute statistics for less than 2 vectors.")
        # Mean and standard deviation from the running sums (variance clamped at 0
        # to guard against floating-point round-off).
        mean = self._sum / self._count
        variance = (self._sum_sq / self._count) - mean**2
        stddev = np.sqrt(np.maximum(0, variance))
        # Quantiles from the sampled rows.
        if self._all_samples:
            all_sampled = np.concatenate(self._all_samples, axis=0)
            q01 = np.quantile(all_sampled, 0.01, axis=0)
            q99 = np.quantile(all_sampled, 0.99, axis=0)
        else:
            q01 = np.zeros_like(mean)
            q99 = np.zeros_like(mean)
        return NormStats(mean=mean, std=stddev, q01=q01, q99=q99)
class _NormStatsDict(pydantic.BaseModel):
    """JSON (de)serialization wrapper around a mapping of feature name -> NormStats."""

    norm_stats: dict[str, NormStats]
def serialize_json(norm_stats: dict[str, NormStats]) -> str:
    """Serialize the running statistics to a pretty-printed JSON string."""
    wrapper = _NormStatsDict(norm_stats=norm_stats)
    return wrapper.model_dump_json(indent=2)
def deserialize_json(data: str) -> dict[str, NormStats]:
    """Parse a JSON string produced by `serialize_json` back into NormStats."""
    parsed = json.loads(data)
    return _NormStatsDict(**parsed).norm_stats
def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None:
    """Write the normalization stats as `norm_stats.json` under `directory`."""
    out_path = pathlib.Path(directory) / "norm_stats.json"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(serialize_json(norm_stats))
def load(directory: pathlib.Path | str) -> dict[str, NormStats]:
    """Read normalization stats from `norm_stats.json` under `directory`.

    Raises:
        FileNotFoundError: if the stats file does not exist.
    """
    path = pathlib.Path(directory) / "norm_stats.json"
    if path.exists():
        return deserialize_json(path.read_text())
    raise FileNotFoundError(f"Norm stats file not found at: {path}")

View File

@@ -0,0 +1,43 @@
import numpy as np
import openpi.shared.normalize as normalize
def test_normalize_update():
    """Feeding RunningStats one row at a time must match numpy's batch results."""
    data = np.arange(12).reshape(4, 3)  # 4 vectors of length 3
    tracker = normalize.RunningStats()
    for row in range(data.shape[0]):
        tracker.update(data[row : row + 1])  # Update with one vector at a time
    computed = tracker.get_statistics()
    assert np.allclose(computed.mean, np.mean(data, axis=0))
    assert np.allclose(computed.std, np.std(data, axis=0))
def test_serialize_deserialize():
    """A serialize -> deserialize round trip must preserve mean and std."""
    tracker = normalize.RunningStats()
    tracker.update(np.arange(12).reshape(4, 3))  # 4 vectors of length 3
    original = {"test": tracker.get_statistics()}
    round_tripped = normalize.deserialize_json(normalize.serialize_json(original))
    assert np.allclose(original["test"].mean, round_tripped["test"].mean)
    assert np.allclose(original["test"].std, round_tripped["test"].std)
def test_multiple_batch_dimensions():
    """Leading batch dimensions should be flattened before computing stats."""
    # (2, 3, 4): two leading batch dims, vector dimension of 4.
    data = np.random.rand(2, 3, 4)
    tracker = normalize.RunningStats()
    tracker.update(data)  # internally reshaped to (6, 4)
    computed = tracker.get_statistics()
    # Expected values come from explicitly flattening the batch dims.
    flat = data.reshape(-1, data.shape[-1])
    assert np.allclose(computed.mean, np.mean(flat, axis=0))
    assert np.allclose(computed.std, np.std(flat, axis=0))

View File

@@ -0,0 +1,96 @@
"""Compute normalization statistics for a config.
This script is used to compute the normalization statistics for a given config. It
will compute the mean and standard deviation of the data in the dataset and save it
to the config assets directory.
"""
import numpy as np
import tqdm
import tyro
import openpi.models.model as _model
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.mixture_dataset as _mixture_dataset
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms
from pdb import set_trace
import openpi.training.weight_loaders as weight_loaders
import openpi.models.pi0_config as pi0_config
# from openpi.training.config import MultiSimGenieDataConfig, MultiSimSplitAlohaDataConfig, MultiSimFrankaDataConfig, MultiLeRobotReala2dDataConfig, MultiLeRobotRealArxLift2DataConfig, MultiDataConfig, DataConfig, TrainConfig
import logging
from pdb import set_trace
from typing import List
class RemoveStrings(transforms.DataTransformFn):
    """Drop all string-valued entries from a sample (JAX cannot carry strings)."""

    def __call__(self, x: dict) -> dict:
        filtered = {}
        for key, value in x.items():
            if not np.issubdtype(np.asarray(value).dtype, np.str_):
                filtered[key] = value
        return filtered
def create_torch_dataloader(
    data_config: List[_config.DataConfig],
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.TorchDataLoader, int]:
    """Build a torch data loader over the mixture dataset for norm-stat computation.

    Args:
        data_config: Per-dataset configs. NOTE(review): the body indexes
            `data_config[0][0]`, so this is actually a nested sequence of
            DataConfig; the `List[_config.DataConfig]` annotation looks
            inaccurate — confirm against callers.
        action_horizon: Number of future actions per sample.
        batch_size: Batch size used by the loader.
        model_config: Model config used to build the dataset.
        num_workers: Torch DataLoader worker processes.
        max_frames: If set and smaller than the dataset, limit (with shuffling)
            to roughly this many frames.

    Returns:
        Tuple of (data loader, number of batches it will yield).
    """
    # if data_config.repo_id is None:
    #     raise ValueError("Data config must have a repo_id")
    # dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
    dataset = _mixture_dataset.create_mixture_dataset_no_transform(data_config, action_horizon, model_config)
    # from pdb import set_trace; set_trace()
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            # NOTE(review): transforms are taken from the FIRST dataset config only;
            # assumes all mixture members share compatible transforms — confirm.
            *data_config[0][0].repack_transforms.inputs,
            *data_config[0][0].data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
    )
    # Shuffle only when subsampling, so the subset is representative.
    if max_frames is not None and max_frames < len(dataset):
        num_batches = max_frames // batch_size
        shuffle = True
    else:
        num_batches = len(dataset) // batch_size
        shuffle = False
    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        num_batches=num_batches,
    )
    return data_loader, num_batches
def compute_norm_stats(config_name: str, max_frames: int | None = None):
    """Compute normalization statistics for `state` and `actions` of a config's data.

    Args:
        config_name: Name of the training config to load.
        max_frames: If set, limit the number of frames used for the computation.

    Returns:
        Mapping from key ("state"/"actions") to its computed NormStats.
    """
    config = _config.get_config(config_name)

    data_configs_list = []
    for data_config_factory in config.data:
        data_configs = data_config_factory.create(config.model)
        logging.info(f"data_config: {data_configs}")
        data_configs_list.append(data_configs)

    data_loader, num_batches = create_torch_dataloader(
        data_configs_list,
        config.model.action_horizon,
        config.batch_size,
        config.model,
        config.num_workers,
        # BUG FIX: was hard-coded to `max_frames=None`, silently ignoring the
        # caller-provided limit.
        max_frames=max_frames,
    )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    # Safety cap so a huge dataset cannot make this run unbounded.
    max_steps = 10_000
    step_id = 0
    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        step_id += 1
        for key in keys:
            stats[key].update(np.asarray(batch[key]))
        if step_id > max_steps:
            break

    # Avoid shadowing `stats` inside the comprehension (was `for key, stats in stats.items()`).
    norm_stats = {key: tracker.get_statistics() for key, tracker in stats.items()}
    print(norm_stats)
    return norm_stats

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
import asyncio
import concurrent.futures as futures
import dataclasses
import logging
from typing import Protocol
from etils import epath
import jax
import orbax.checkpoint as ocp
import orbax.checkpoint.future as future
from openpi.shared import array_typing as at
import openpi.shared.normalize as _normalize
import openpi.training.data_loader as _data_loader
import openpi.training.utils as training_utils
def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    """Create (or reuse) a checkpoint directory and its orbax CheckpointManager.

    Args:
        checkpoint_dir: Target directory for checkpoints.
        keep_period: If set, checkpoints at step multiples of this period are kept
            in addition to the most recent one.
        overwrite: Wipe an existing directory and start fresh.
        resume: Resume from an existing directory instead of failing.

    Returns:
        Tuple of (checkpoint manager, whether a previous checkpoint should be
        restored).

    Raises:
        FileExistsError: the directory exists and neither `overwrite` nor
            `resume` was requested.
    """
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    resuming = False
    if checkpoint_dir.exists():
        if overwrite:
            checkpoint_dir.rmtree()
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
        elif resume:
            resuming = True
        else:
            raise FileExistsError(
                f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
                "to indicate how to handle it."
            )
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    mngr = ocp.CheckpointManager(
        checkpoint_dir,
        item_handlers={
            "assets": CallbackHandler(),  # runs an arbitrary callback (e.g. saving norm stats)
            "train_state": ocp.PyTreeCheckpointHandler(),
            "params": ocp.PyTreeCheckpointHandler(),
        },
        options=ocp.CheckpointManagerOptions(
            max_to_keep=1,
            keep_period=keep_period,
            create=False,  # directory was already created above
            async_options=ocp.AsyncOptions(timeout_secs=7200),
        ),
    )
    # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
    # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
    # checkpoint, since it will fail.
    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
        resuming = False
    return mngr, resuming
def save_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int,
):
    """Save the train state, inference params, and assets for `step`."""

    def save_assets(directory: epath.Path):
        # Persist the normalization stats alongside the checkpoint so that
        # inference can locate them under the same asset id.
        cfg = data_loader.data_config()
        if cfg.norm_stats is not None and cfg.asset_id is not None:
            _normalize.save(directory / cfg.asset_id, cfg.norm_stats)

    # Split params that can be used for inference into a separate item.
    with at.disable_typechecking():
        train_state, params = _split_params(state)
    checkpoint_manager.save(
        step,
        {
            "assets": save_assets,
            "train_state": train_state,
            "params": {"params": params},
        },
    )
def restore_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int | None = None,
) -> training_utils.TrainState:
    """Restore a train state saved by `save_state` (latest step when `step` is None)."""
    del data_loader
    with at.disable_typechecking():
        # Mirror the save layout: train state and inference params live in
        # separate checkpoint items.
        template_state, template_params = _split_params(state)
        restored = checkpoint_manager.restore(
            step,
            items={
                "train_state": template_state,
                "params": {"params": template_params},
            },
        )
    return _merge_params(restored["train_state"], restored["params"])
def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
    """Load normalization stats stored under `assets_dir/asset_id`."""
    norm_stats_dir = epath.Path(assets_dir) / asset_id
    stats = _normalize.load(norm_stats_dir)
    logging.info(f"Loaded norm stats from {norm_stats_dir}")
    return stats
class Callback(Protocol):
    """Callable invoked with the checkpoint assets directory during save."""

    def __call__(self, directory: epath.Path) -> None: ...
class CallbackHandler(ocp.AsyncCheckpointHandler):
    """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""

    def save(self, directory: epath.Path, args: CallbackSave):
        # Only the primary host runs the callback; other hosts no-op.
        if jax.process_index() == 0:
            args.callback(directory)

    async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
        coro = asyncio.to_thread(self.save, directory, args)
        return [future.CommitFutureAwaitingContractedSignals(coro)]

    def restore(self, *args, **kwargs):
        raise NotImplementedError("CallbackHandler does not support restore")
@ocp.args.register_with_handler(CallbackHandler, for_save=True)
@dataclasses.dataclass
class CallbackSave(ocp.args.CheckpointArgs):
    """Save-time arguments for `CallbackHandler`: the callback to run."""

    callback: Callback
# Registered for API symmetry only; `CallbackHandler.restore` always raises.
@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
class CallbackRestore(ocp.args.CheckpointArgs): ...
def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
if state.ema_params is not None:
params = state.ema_params
train_state = dataclasses.replace(state, ema_params=None)
else:
params = state.params
train_state = dataclasses.replace(state, params={})
return train_state, params
def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
# Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split.
if train_state.params:
return dataclasses.replace(train_state, ema_params=params["params"])
return dataclasses.replace(train_state, params=params["params"])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,721 @@
from collections.abc import Iterator, Sequence
import logging
import multiprocessing
import os
import typing
from typing import Literal, Protocol, SupportsIndex, TypeVar, Dict
import sys
import jax
import jax.numpy as jnp
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
import numpy as np
import torch
from torch.utils.data import dataloader
from torch.multiprocessing import reductions
from multiprocessing.reduction import ForkingPickler
import openpi.shared.normalize as normalize
default_collate_func = dataloader.default_collate
import psutil
# Monkeypatch torch's default_collate so worker processes never place collated
# tensors in shared memory.
# NOTE(review): presumably a workaround for /dev/shm exhaustion with many
# dataloader workers — confirm the motivation before removing.
def default_collate_override(batch):
    dataloader._use_shared_memory = False
    return default_collate_func(batch)
setattr(dataloader, 'default_collate', default_collate_override)
# Strip torch's custom ForkingPickler reducers for its storage classes so that
# tensors are pickled by value rather than via shared-memory handles.
# NOTE(review): the `sys.version_info[0] == 2` branch is effectively dead on a
# Python 3 codebase; only the `_extra_reducers` branch ever runs.
for t in torch._storage_classes:
    if sys.version_info[0] == 2:
        if t in ForkingPickler.dispatch:
            del ForkingPickler.dispatch[t]
    else:
        if t in ForkingPickler._extra_reducers:
            del ForkingPickler._extra_reducers[t]
import openpi.models.model as _model
import openpi.training.config as _config
from openpi.training.droid_rlds_dataset import DroidRldsDataset
import openpi.transforms as _transforms
from openpi.training.mixture_dataset import create_mixture_dataset
T_co = TypeVar("T_co", covariant=True)
import copy
from memory_profiler import profile
from pdb import set_trace
class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Return the sample at `index`."""
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        """Return the number of samples."""
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
class IterableDataset(Protocol[T_co]):
    """Interface for an iterable dataset."""

    def __iter__(self) -> Iterator[T_co]:
        """Yield samples in iteration order."""
        raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.")

    def __len__(self) -> int:
        """Return the number of samples."""
        # Fixed copy/paste: the message previously referenced `Dataset`.
        raise NotImplementedError("Subclasses of IterableDataset should implement __len__.")
class DataLoader(Protocol[T_co]):
    """Interface for a data loader."""

    def data_config(self) -> _config.DataConfig:
        """Get the data config for this data loader."""
        raise NotImplementedError("Subclasses of DataLoader should implement data_config.")

    def __iter__(self) -> Iterator[T_co]:
        """Yield batches."""
        raise NotImplementedError("Subclasses of DataLoader should implement __iter__.")
class TransformedDataset(Dataset[T_co]):
    """Wrap a random-access dataset, applying composed transforms on each access."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw_sample = self._dataset[index]
        return self._transform(raw_sample)

    def __len__(self) -> int:
        return len(self._dataset)
class IterableTransformedDataset(IterableDataset[T_co]):
    """Wrap an iterable dataset, applying composed transforms to each yielded sample.

    When `is_batched` is set, each incoming sample is a full batch: it is
    unstacked along the leading axis, transformed per element, and restacked.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        transforms: Sequence[_transforms.DataTransformFn],
        *,
        is_batched: bool = False,
    ):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)
        self._is_batched = is_batched

    def __iter__(self):
        for sample in self._dataset:
            if self._is_batched:
                # Transforms are designed to be applied to individual samples. So we need to split the batch into
                # individual samples and apply the transform to each sample individually.
                # Assumes every leaf of `sample` shares the same leading batch dimension.
                batch_size = next(v.shape[0] for v in sample.values())
                # Split batch into individual samples using tree_map.
                individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)]  # noqa: B023
                # Transform each sample.
                transformed = [self._transform(s) for s in individual_samples]
                # Recombine batch with tree_map.
                yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed)
            else:
                yield self._transform(sample)

    def __len__(self) -> int:
        return len(self._dataset)
class FakeDataset(Dataset):
    """Synthetic dataset sampling random observations/actions that match the model spec."""

    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        self._num_samples = num_samples
        self._observation_spec, self._action_spec = model_config.inputs_spec()

    def __getitem__(self, index: SupportsIndex) -> dict:
        # Deterministic per index: the index seeds the PRNG.
        rng = jax.random.key(index.__index__())

        def make_from_spec(spec: jax.ShapeDtypeStruct):
            # Thread the RNG through successive leaves via `nonlocal`.
            nonlocal rng
            rng, data_rng = jax.random.split(rng)
            # Remove the batch dimension.
            shape = spec.shape[1:]
            if spec.dtype == jnp.float32:
                return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0)
            if spec.dtype == jnp.int32:
                return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048)
            # Any other dtype (e.g. bool masks) gets zeros.
            return jnp.zeros(shape=shape, dtype=spec.dtype)

        observation = jax.tree.map(make_from_spec, self._observation_spec)
        action = jax.tree.map(make_from_spec, self._action_spec)
        return {
            **observation.to_dict(),
            "actions": action,
        }

    def __len__(self) -> int:
        return self._num_samples
def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training.

    Returns a FakeDataset for the sentinel repo id "fake"; otherwise builds a
    LeRobot dataset with per-key action-sequence timestamps, optionally wrapped
    to derive the prompt from the LeRobot task.
    """
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    # One timestamp per step of the action horizon, spaced at the dataset's FPS.
    delta_timestamps = {
        key: [t / dataset_meta.fps for t in range(action_horizon)]
        for key in data_config.action_sequence_keys
    }
    dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, delta_timestamps=delta_timestamps)
    if data_config.prompt_from_task:
        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])
    return dataset
def create_rlds_dataset(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    shuffle: bool = False,
) -> Dataset:
    """Build an RLDS-backed dataset.

    At the moment, we only support DROID for RLDS datasets.
    """
    return DroidRldsDataset(
        data_dir=data_config.rlds_data_dir,
        action_chunk_size=action_horizon,
        action_space=data_config.action_space,
        filter_dict_path=data_config.filter_dict_path,
        batch_size=batch_size,
        shuffle=shuffle,
    )
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Wrap `dataset` with the repack, data, normalization, and model transforms.

    Raises:
        ValueError: if normalization stats are required but missing from the config.
    """
    norm_stats = {}
    needs_stats = data_config.repo_id != "fake" and not skip_norm_stats
    if needs_stats:
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return TransformedDataset(dataset, pipeline)
def transform_iterable_dataset(
    dataset: IterableDataset,
    data_config: _config.DataConfig,
    *,
    skip_norm_stats: bool = False,
    is_batched: bool = False,
) -> IterableDataset:
    """Iterable counterpart of `transform_dataset`: wrap with the full transform pipeline.

    Raises:
        ValueError: if normalization stats are required but missing from the config.
    """
    norm_stats = {}
    needs_stats = data_config.repo_id != "fake" and not skip_norm_stats
    if needs_stats:
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return IterableTransformedDataset(dataset, pipeline, is_batched=is_batched)
def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Dispatches to the RLDS-backed loader when the config points at an RLDS data
    directory; otherwise uses the torch/LeRobot-backed loader.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    logging.info(f"data_config: {data_config}")
    if data_config.rlds_data_dir is not None:
        return create_rlds_data_loader(
            data_config,
            action_horizon=config.model.action_horizon,
            batch_size=config.batch_size,
            sharding=sharding,
            shuffle=shuffle,
            num_batches=num_batches,
            skip_norm_stats=skip_norm_stats,
            framework=framework,
        )
    return create_torch_data_loader(
        data_config,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
    )
def create_data_loader_multi(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader over a mixture of datasets for multi-dataset training.

    Args:
        config: The training configuration; `config.data` is iterated as a
            sequence of data-config factories.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
        global_norm_stats: Optional shared normalization stats passed to every
            data-config factory and down to the loader.
    """
    data_configs_list = []
    for data_config_factory in config.data:
        data_configs = data_config_factory.create(config.model, global_norm_stats)
        logging.info(f"data_config: {data_configs}")
        data_configs_list.append(data_configs)
    return create_torch_data_loader_multi(
        data_configs_list,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
        global_norm_stats=global_norm_stats,
    )
def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        model_config: The model configuration used to build the dataset.
        action_horizon: The action horizon.
        batch_size: The global batch size (divided across processes/ranks).
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding. Ignored when `framework` is "pytorch".
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: "jax" or "pytorch"; selects the distributed sampling strategy.
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)
    # Use TorchDataLoader for both frameworks.
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size.
    # For JAX, divide by process count.
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            # Multi-host JAX: each process reads a disjoint shard of the dataset.
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )
    return DataLoaderImpl(data_config, data_loader)
def create_torch_data_loader_multi(
    data_configs_list: list[_config.DataConfig],
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader that trains on a mixture of datasets.

    Args:
        data_configs_list: The data configurations, one per dataset in the mixture.
        model_config: The model configuration used to build the mixture dataset.
        action_horizon: The action horizon.
        batch_size: The global batch size; divided across processes/ranks below.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
            NOTE(review): accepted but not forwarded anywhere in this function -- confirm
            whether `create_mixture_dataset` should receive it.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: Either "jax" or "pytorch"; selects the distributed sampling strategy.
        global_norm_stats: Optional pre-computed normalization stats.
            NOTE(review): currently unused in this function -- confirm intended use.
    """
    dataset = create_mixture_dataset(data_configs_list, action_horizon, model_config)
    # Use TorchDataLoader for both frameworks.
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size.
    # For JAX, divide by process count.
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )
    # NOTE(review): `data_configs_list` is typed as list[DataConfig], yet this indexes two
    # levels deep -- presumably the list actually holds (config, weight)-style pairs.
    # Confirm against callers before relying on the annotation.
    return DataLoaderImpl(data_configs_list[0][0], data_loader)
def create_rlds_data_loader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Build a training data loader backed by an RLDS dataset.

    Note: this loader requires extra dependencies -- see examples/droid/README_train.md.

    Args:
        data_config: The data configuration.
        action_horizon: The action horizon.
        batch_size: The batch size.
        sharding: Sharding for the produced arrays; single-device sharding when None.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Number of batches to yield; loops over the dataset if larger than the
            dataset, and iterates indefinitely when None.
        framework: Only "jax" is supported for RLDS.
    """
    # RLDS pipelines are TF-based; the PyTorch path has no implementation yet.
    if framework == "pytorch":
        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
    rlds_dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
    transformed = transform_iterable_dataset(rlds_dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
    loader = RLDSDataLoader(
        transformed,
        sharding=sharding,
        num_batches=num_batches,
    )
    return DataLoaderImpl(data_config, loader)
class JaxProcessDistributedSampler(torch.utils.data.Sampler[int]):
    """Split dataset indices across JAX processes.

    Each process sees a disjoint, equally-sized slice of the (optionally shuffled) index
    list, obtained by striding by ``num_replicas``. The remainder is dropped so that every
    process yields exactly the same number of indices -- required to keep multi-process
    training steps in lockstep. Shuffling (if enabled) is deterministic via the provided
    seed.
    """

    def __init__(
        self,
        dataset_size: int,
        *,
        num_replicas: int,
        rank: int,
        shuffle: bool,
        seed: int,
    ) -> None:
        # Clamp inputs defensively; a zero-size dataset yields an empty sampler.
        self._dataset_size = max(0, dataset_size)
        self._num_replicas = max(1, num_replicas)
        self._rank = max(0, rank)
        self._shuffle = shuffle
        self._seed = seed

    def __iter__(self):
        indices = list(range(self._dataset_size))
        if self._shuffle and self._dataset_size > 0:
            g = torch.Generator()
            g.manual_seed(self._seed)
            indices = torch.randperm(self._dataset_size, generator=g).tolist()
        # Drop the remainder FIRST so every rank gets the same number of indices, then take
        # the strided slice. (Previously the remainder was kept, which handed rank 0 more
        # indices than the highest ranks and could desynchronize multi-process loops.)
        usable = (self._dataset_size // self._num_replicas) * self._num_replicas
        indices = indices[:usable][self._rank :: self._num_replicas]
        return iter(indices)

    def __len__(self) -> int:
        # Matches the remainder-dropped, strided selection length for every rank.
        return self._dataset_size // self._num_replicas
class TorchDataLoader:
    """Torch data loader implementation shared by the JAX and PyTorch training paths."""

    def __init__(
        self,
        dataset,
        local_batch_size: int,
        *,
        sharding: jax.sharding.Sharding | None = None,
        shuffle: bool = False,
        sampler: torch.utils.data.Sampler | None = None,
        num_batches: int | None = None,
        num_workers: int = 0,
        seed: int = 0,
        framework: str = "jax",
    ):
        """Create a PyTorch data loader.

        Args:
            dataset: The dataset to load.
            local_batch_size: The local batch size for each process.
            sharding: The sharding to use for the data loader. For JAX, defaults to
                data-parallel sharding when None; for PyTorch it stays None and torch
                tensors are yielded instead of sharded arrays.
            shuffle: Whether to shuffle the data. Ignored when a sampler is provided.
            sampler: Optional sampler that owns ordering/sharding across ranks.
            num_batches: If provided, determines the number of returned batches. If the
                number is larger than the number of batches in the dataset, the data loader
                will loop over the dataset. If not provided, will iterate over the dataset
                indefinitely.
            num_workers: The number of worker processes to use. If zero, the data loader will
                execute in the main process.
            seed: The seed to use for shuffling the data.
            framework: Either "jax" or "pytorch"; controls the yielded array type.
        """
        if len(dataset) < local_batch_size:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

        # Store sharding - None for PyTorch, JAX sharding for JAX
        self._sharding = sharding
        if sharding is None and framework == "jax":
            # Use data parallel sharding by default for JAX only.
            self._sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._num_batches = num_batches

        mp_context = None
        if num_workers > 0:
            # Spawn (rather than fork) so worker processes don't inherit JAX/CUDA state.
            mp_context = multiprocessing.get_context("spawn")

        generator = torch.Generator()
        generator.manual_seed(seed)
        self._data_loader = torch.utils.data.DataLoader(
            typing.cast(torch.utils.data.Dataset, dataset),
            batch_size=local_batch_size,
            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
            sampler=sampler,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=num_workers > 0,
            collate_fn=_collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=True,
            generator=generator,
            pin_memory=False,
        )

    @property
    def torch_loader(self) -> torch.utils.data.DataLoader:
        """The underlying torch DataLoader."""
        return self._data_loader

    def __iter__(self):
        # NOTE: removed a bare `@profile` decorator here -- that name only exists when
        # running under line_profiler/kernprof and raised NameError in normal runs.
        num_items = 0
        while True:
            data_iter = iter(self._data_loader)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
                if self._sharding is not None:
                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                else:
                    yield jax.tree.map(torch.as_tensor, batch)
def _collate_fn(items):
"""Collate the batch elements into batched numpy arrays."""
# Make sure to convert to numpy arrays before stacking since some of the incoming elements
# may be JAX arrays.
return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)
def _worker_init_fn(worker_id: int) -> None:
"""Tell JAX inside the worker process not to preallocate the GPU memory."""
# NOTE: This is called after jax is imported inside the worker process. This
# means that this approach will not work for selecting the backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
class RLDSDataLoader:
    """Shallow wrapper around the DROID data loader to make it compatible with openpi.

    All batching already happens in the DROID dataset, so we don't need to do anything here.
    """

    def __init__(
        self,
        dataset: DroidRldsDataset,
        *,
        sharding: jax.sharding.Sharding | None = None,
        num_batches: int | None = None,
    ):
        """Wrap an already-batched RLDS dataset.

        Args:
            dataset: The batched DROID RLDS dataset to iterate.
            sharding: Sharding for the produced JAX arrays; data-parallel by default.
            num_batches: Optional cap on the number of batches to yield; the dataset is
                re-iterated as needed. If None, iterates indefinitely.

        Raises:
            NotImplementedError: If running with more than one JAX process.
        """
        self._dataset = dataset
        # Set once here (was previously assigned twice -- the duplicate was removed).
        self._num_batches = num_batches
        if jax.process_count() > 1:
            raise NotImplementedError("Data loading with multiple processes is not supported.")
        if sharding is None:
            # Use data parallel sharding by default.
            sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._sharding = sharding

    def __iter__(self):
        """Yield sharded JAX batches, looping the dataset until num_batches is reached."""
        num_items = 0
        while True:
            data_iter = iter(self._dataset)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
class DataLoaderImpl(DataLoader):
    """Pairs a concrete loader with its data config and yields model-ready batches."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
        self._data_config = data_config
        self._data_loader = data_loader

    def data_config(self) -> _config.DataConfig:
        """Return the data config associated with this loader."""
        return self._data_config

    def __iter__(self):
        # Split each raw batch dict into an (Observation, actions) pair for training.
        for batch in self._data_loader:
            yield _model.Observation.from_dict(batch), batch["actions"]

View File

@@ -0,0 +1,221 @@
"""
RLDS-based data loader for DROID.
While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID.
Thus, we provide a data loader example here that uses the RLDS data format.
The data loader also applies a few DROID-specific data filters / transformations.
"""
from enum import Enum
from enum import auto
import json
import logging
from pathlib import Path
import tqdm
import openpi.shared.download as download
class DroidActionSpace(Enum):
    """Supported action parameterizations for the DROID dataset."""

    # Absolute joint positions (allow policy evaluation in simulation).
    JOINT_POSITION = auto()
    # Joint velocities.
    JOINT_VELOCITY = auto()
class DroidRldsDataset:
    """Iterable over batched DROID training frames built from an RLDS/TFDS pipeline."""

    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        *,  # Force keyword-only arguments
        shuffle: bool = True,
        action_chunk_size: int = 16,
        # We default to joint position actions, since they allow policy evaluation in simulation.
        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
        # NOTE(review): currently unused inside this constructor -- confirm intended use.
        max_loaded_steps_per_episode: int = 100,
        # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        filter_dict_path=None,  # Path to json file with indices to sample during training
    ):
        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
        import dlimp as dl
        import tensorflow as tf
        import tensorflow_datasets as tfds

        # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
        tf.config.set_visible_devices([], "GPU")
        builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1")
        dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)
        # Filter out any unsuccessful trajectories -- we use the file name to check this
        dataset = dataset.filter(
            lambda traj: tf.strings.regex_full_match(
                traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
            )
        )
        # Repeat dataset so we never run out of data.
        dataset = dataset.repeat()
        # Load the filter dictionary if provided.
        # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample
        # (e.g.,
        # {
        #     "<episode key>": [[0, 100], [200, 300]]
        # }
        # means keep frames 0-99 and 200-299).
        if filter_dict_path is not None:
            cached_filter_dict_path = download.maybe_download(filter_dict_path)
            with Path(cached_filter_dict_path).open("r") as f:
                filter_dict = json.load(f)
            logging.info(f"Using filter dictionary with {len(filter_dict)} episodes")
            keys_tensor = []
            values_tensor = []
            for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."):
                for start, end in ranges:
                    for t in range(start, end):
                        frame_key = f"{episode_key}--{t}"
                        keys_tensor.append(frame_key)
                        values_tensor.append(True)
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False
            )
            logging.info("Filter hash table initialized")
        else:
            # No filter provided: fall back to a table that passes every key.
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True
            )

        def restructure(traj):
            """Reformat observation and action keys, sample language instruction."""
            # Build the action vector from the configured action space plus the gripper
            # position (joint positions by default -- easier to simulate).
            actions = tf.concat(
                (
                    (
                        traj["action_dict"]["joint_position"]
                        if action_space == DroidActionSpace.JOINT_POSITION
                        else traj["action_dict"]["joint_velocity"]
                    ),
                    traj["action_dict"]["gripper_position"],
                ),
                axis=-1,
            )
            # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
            # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera.
            exterior_img = tf.cond(
                tf.random.uniform(shape=[]) > 0.5,
                lambda: traj["observation"]["exterior_image_1_left"],
                lambda: traj["observation"]["exterior_image_2_left"],
            )
            wrist_img = traj["observation"]["wrist_image_left"]
            # Randomly sample one of the three language instructions
            instruction = tf.random.shuffle(
                [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
            )[0]
            traj_len = tf.shape(traj["action"])[0]
            indices = tf.as_string(tf.range(traj_len))
            # Data filtering:
            # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,
            # and each step's time step index. This will index into the filter hash table, and if it returns true,
            # then the frame passes the filter.
            step_id = (
                traj["traj_metadata"]["episode_metadata"]["recording_folderpath"]
                + "--"
                + traj["traj_metadata"]["episode_metadata"]["file_path"]
                + "--"
                + indices
            )
            passes_filter = self.filter_table.lookup(step_id)

            return {
                "actions": actions,
                "observation": {
                    "image": exterior_img,
                    "wrist_image": wrist_img,
                    "joint_position": traj["observation"]["joint_position"],
                    "gripper_position": traj["observation"]["gripper_position"],
                },
                "prompt": instruction,
                "step_id": step_id,
                "passes_filter": passes_filter,
            }

        dataset = dataset.traj_map(restructure, num_parallel_calls)

        def chunk_actions(traj):
            """Splits episode into action chunks."""
            traj_len = tf.shape(traj["actions"])[0]
            # For each step in the trajectory, construct indices for the next n actions
            action_chunk_indices = tf.broadcast_to(
                tf.range(action_chunk_size)[None],
                [traj_len, action_chunk_size],
            ) + tf.broadcast_to(
                tf.range(traj_len)[:, None],
                [traj_len, action_chunk_size],
            )
            # Cap to length of the sequence --> final chunks will repeat the last action
            # This makes sense, since we are using absolute joint + gripper position actions
            action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)
            # Gather the actions for each chunk
            traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
            return traj

        dataset = dataset.traj_map(chunk_actions, num_parallel_calls)
        # Flatten: map from trajectory dataset to dataset of individual action chunks
        dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)

        # Filter data that doesn't pass the filter
        def filter_from_dict(frame):
            return frame["passes_filter"]

        dataset = dataset.filter(filter_from_dict)

        # Remove "passes_filter" key from output
        def remove_passes_filter(frame):
            frame.pop("passes_filter")
            return frame

        dataset = dataset.map(remove_passes_filter)

        # Decode images: RLDS saves encoded images, only decode now for efficiency
        def decode_images(traj):
            traj["observation"]["image"] = tf.io.decode_image(
                traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
            )
            traj["observation"]["wrist_image"] = tf.io.decode_image(
                traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
            )
            return traj

        dataset = dataset.frame_map(decode_images, num_parallel_calls)

        # Shuffle, batch
        dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.batch(batch_size)
        # Note =>> Seems to reduce memory usage without affecting speed?
        dataset = dataset.with_ram_budget(1)

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        """Yield batches of numpy arrays from the underlying tf.data pipeline."""
        yield from self.dataset.as_numpy_iterator()

    def __len__(self):
        # This is the approximate number of samples in DROID after filtering.
        # Easier to hardcode than to iterate through the dataset and compute it.
        return 20_000_000

View File

@@ -0,0 +1,116 @@
"""RoboArena baseline policy configs."""
from typing import TypeAlias
import openpi.models.model as _model
import openpi.models.pi0_config as pi0_config
import openpi.models.pi0_fast as pi0_fast
import openpi.models.tokenizer as _tokenizer
import openpi.policies.droid_policy as droid_policy
import openpi.transforms as _transforms
# Short alias so the configs below can reference model types concisely.
ModelType: TypeAlias = _model.ModelType
def get_roboarena_configs():
    """Return the TrainConfigs for the RoboArena DROID baseline policies.

    Each entry pairs a model variant (binning / FAST / FSQ tokenizer, or flow-based pi0)
    with the same DROID data pipeline; these are inference-oriented configs.
    """
    # Import here to avoid circular imports.
    from openpi.training.config import AssetsConfig
    from openpi.training.config import DataConfig
    from openpi.training.config import SimpleDataConfig
    from openpi.training.config import TrainConfig

    return [
        #
        # RoboArena DROID baseline inference configs.
        #
        TrainConfig(
            # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
            name="paligemma_binning_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                max_token_len=400,
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
            name="paligemma_fast_droid",
            model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
            name="paligemma_fast_specialist_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FASTTokenizer,
                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FSQ tokenizer.
            name="paligemma_vq_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FSQTokenizer,
                fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
            name="paligemma_diffusion_droid",
            model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
    ]

View File

@@ -0,0 +1,703 @@
import numpy as np
from dataclasses import dataclass
from typing import SupportsIndex, Sequence, List, Dict, Any, Tuple, Optional, Union, TypeVar, Protocol
import torch
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
import openpi.transforms as _transforms
from pdb import set_trace
import logging
T_co = TypeVar("T_co", covariant=True)
import openpi.training.config as _config
import openpi.shared.normalize as normalize
def detect_gripper_change_step(
    dataset,
    select_actions: Sequence[str] = ("action",),
    gripper_dim: int = -1,
    threshold_method: str = "std_multiplier",
    threshold_multiplier: float = 2.0,
    min_threshold: float = 0.001,
    max_threshold: float = 1.0,
    plot_gripper_changes: bool = False,
):
    """Detect the steps at which the gripper state changes.

    Only works for the self-collected dataset. Modifies ``dataset`` in place by attaching a
    ``gripper_change_step_idx`` attribute (sorted, de-duplicated int32 step indices).
    Around each significant gripper delta a sliding window of size 4 is marked.

    Args:
        dataset: LeRobotDataset instance (needs ``meta.episodes`` and ``hf_dataset``).
        select_actions: Action keys to process. Defaults to ``("action",)``. (Previously a
            mutable list default -- a classic Python pitfall -- now an immutable tuple.)
        gripper_dim: Dimension index for the gripper in the action vector.
        threshold_method: One of 'std_multiplier', 'percentile', 'absolute'.
        threshold_multiplier: Multiplier for the std-based threshold, or the absolute
            threshold itself when ``threshold_method == 'absolute'``.
        min_threshold: Lower clamp on the threshold, to avoid over-sensitive detection.
        max_threshold: Upper clamp on the threshold, to avoid missing large changes.
        plot_gripper_changes: Whether to plot a gripper-change visualization.

    Returns:
        The same ``dataset`` object, with ``gripper_change_step_idx`` attached.

    Raises:
        ValueError: If ``threshold_method`` is unknown.
    """
    episode_lengths = [ep_dict["length"] for ep_dict in dataset.meta.episodes.values()]
    cumulative_lengths = np.cumsum(episode_lengths)
    all_window_indices = set()  # Set gives automatic deduplication across action keys.
    for action_key in select_actions:
        action_values = dataset.hf_dataset[action_key]
        delta_action = np.diff(action_values, axis=0)
        # The diff across an episode boundary is meaningless; replace it with the previous
        # in-episode diff (or zero if there is none).
        for end_idx in cumulative_lengths[:-1]:
            if end_idx - 1 < len(delta_action) and end_idx - 2 >= 0:
                delta_action[end_idx - 1] = delta_action[end_idx - 2]
            elif end_idx - 1 < len(delta_action):
                delta_action[end_idx - 1] = 0
        if delta_action.ndim == 1:
            delta_action = delta_action[:, np.newaxis]
        assert delta_action.ndim == 2
        # Extract gripper delta values.
        gripper_delta = delta_action[:, gripper_dim]
        # Calculate threshold based on statistical properties.
        if threshold_method == "std_multiplier":
            # Use standard deviation to filter out small tremors.
            threshold = threshold_multiplier * np.std(gripper_delta)
        elif threshold_method == "percentile":
            # Use the 85th percentile of the absolute deltas as the threshold.
            threshold = np.percentile(np.abs(gripper_delta), 85)
        elif threshold_method == "absolute":
            # Use an absolute threshold.
            threshold = threshold_multiplier
        else:
            raise ValueError(f"Unknown threshold_method: {threshold_method}")
        # Clamp threshold to reasonable bounds.
        threshold = np.clip(threshold, min_threshold, max_threshold)
        # Find indices where the gripper change exceeds the threshold.
        significant_change_idx = np.where(np.abs(gripper_delta) > threshold)[0]
        cur_window_indices = set()
        for idx in significant_change_idx:
            # Sliding window of size 4 centered around idx: [idx-2, idx+1], clipped to the
            # valid step range [0, len(action_values) - 1]. Since delta_action[i] is the
            # change from step i to i+1, this marks step i and its neighbors.
            window_start = idx - 2
            window_end = idx + 1
            max_possible_idx = len(action_values) - 1
            current_window_indices = np.arange(
                max(0, window_start), min(max_possible_idx + 1, window_end + 1)
            )
            for w_idx in current_window_indices:
                cur_window_indices.add(w_idx)
                all_window_indices.add(w_idx)
        if plot_gripper_changes:
            # Visualize only the first few episodes to keep the plot readable.
            num_episodes_to_plot = 5
            end_index_for_plot = cumulative_lengths[num_episodes_to_plot - 1] - 1
            delta_action_to_plot = delta_action[:end_index_for_plot]
            # Filter gripper_change_step_idx to the plotted range.
            gripper_change_step_idx = np.asarray(sorted(cur_window_indices), dtype=np.int32)
            gripper_change_step_idx_to_plot = gripper_change_step_idx[gripper_change_step_idx < end_index_for_plot]
            plot_gripper_changes_in_subplots(
                delta_action_to_plot,
                gripper_change_step_idx_to_plot,
                episode_lengths,
                num_episodes_to_plot,
                gripper_dim,
                f"{action_key}_gripper_change",
            )
    # Sorted, de-duplicated union across all selected action keys.
    gripper_change_step_idx = np.asarray(sorted(all_window_indices), dtype=np.int32)
    print(f"Total unique gripper change steps: {len(gripper_change_step_idx)}, Total steps: {len(action_values)}")
    dataset.gripper_change_step_idx = gripper_change_step_idx
    return dataset
class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Return the sample at ``index``."""
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
class TransformedDataset(Dataset[T_co]):
    """Wraps a dataset and applies a composed transform to every fetched sample."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        # Note: `_dataset` is reached into by other code (e.g. MixtureDataset), so the
        # attribute name is load-bearing.
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw_sample = self._dataset[index]
        return self._transform(raw_sample)

    def __len__(self) -> int:
        return len(self._dataset)
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig) -> Dataset:
    """Transform the dataset by applying the data transforms.

    Wraps ``dataset`` with the repack, data, normalization, and model-input transforms
    from ``data_config``, applied in that order.
    """
    # The previous version assigned `norm_stats = {}` and immediately overwrote it; the
    # dead assignment has been removed.
    norm_stats = data_config.norm_stats
    return TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
    )
class MixtureDataset(Dataset):
"""
A composite dataset that combines multiple datasets, allowing for weighted sampling
and specific handling based on training stage (e.g., pretrain, finetune) and
gripper change detection for augmentation.
This dataset flattens all eligible samples from its constituent datasets and assigns
sampling weights based on configuration and heuristics (e.g., `gripper_aug_ratio`).
"""
    def __init__(
        self,
        datasets: Sequence[Dataset],
        datasets_name: Sequence[str],
        datasets_meta: Sequence[LeRobotDatasetMetadata],
        datasets_weights: Optional[Dict[str, float]] = None,
        gripper_aug_ratio: float = 1.0,
        shuffle: bool = True,
    ):
        """
        Initializes the MixtureDataset.

        Args:
            datasets (Sequence[Dataset]): A list of `Dataset` objects to be combined.
            datasets_name (Sequence[str]): A list of names corresponding to each dataset in `datasets`.
            datasets_meta (Sequence[LeRobotDatasetMetadata]): Metadata for each dataset;
                `meta.info` must provide `total_episodes` and `total_frames`.
            datasets_weights (Optional[Dict[str, float]]): A dictionary mapping dataset names
                to their base sampling weights. If None, equal weights are assumed.
            gripper_aug_ratio (float): A multiplier applied to the weights of samples
                that contain a detected gripper change. Useful for augmenting rare events.
            shuffle (bool): If True, the flat sample map and sampling weights are shuffled
                (in the same order) after initial creation.
        """
        self.datasets = datasets
        self.datasets_name = datasets_name
        self.meta = datasets_meta
        # Extract total number of episodes and frames for each dataset from metadata.
        self.num_episodes = [meta.info['total_episodes'] for meta in datasets_meta]
        self.num_frames = [meta.info['total_frames'] for meta in datasets_meta]
        # Compute the flattened list of (dataset_idx, sample_idx) pairs.
        # This involves sampling indices based on the stage and dataset type.
        self._compute_len(False)
        # Assign normalized sampling weights to each sample in the flattened map.
        self._get_weights(datasets_weights, gripper_aug_ratio)
        # For training, ensure the sample map and weights are consistent.
        if len(self.flat_sample_map) != len(self.sample_weights):
            raise ValueError(
                f"Mismatch in flat sample map length ({len(self.flat_sample_map)}) "
                f"and sample weights length ({len(self.sample_weights)})."
            )
        if shuffle:
            # Shuffle both the sample map and weights in the same order for training.
            # This ensures random access to samples while maintaining their assigned probabilities.
            # NOTE(review): uses the global numpy RNG without a seed -- confirm whether
            # reproducible shuffling is required here.
            indices = np.random.permutation(len(self.flat_sample_map))
            self.flat_sample_map = [self.flat_sample_map[i] for i in indices]
            self.sample_weights = self.sample_weights[indices]
    def __len__(self) -> int:
        """
        Returns the total number of samples in the mixture dataset (after flattening and selection).
        This length represents the effective size of the dataset for iteration.
        """
        return len(self.flat_sample_map)
    def __getitem__(self, index: SupportsIndex):
        """
        Retrieves a specific sample from one of the underlying datasets based on the
        flattened sample map.

        Args:
            index (SupportsIndex): The index in the flattened `flat_sample_map` (0 to `len(self) - 1`).

        Returns:
            The sample data (dictionary) from the selected underlying dataset.
            (Note: the dataset index itself is NOT returned -- an earlier docstring
            incorrectly claimed a (dataset_idx, sample) tuple.)

        Raises:
            IndexError: If the provided index is out of bounds for the dataset.
        """
        if not (0 <= index < len(self.flat_sample_map)):
            raise IndexError(f"Index {index} is out of bounds for the dataset (size: {len(self.flat_sample_map)}).")
        # Retrieve the original dataset index and sample index from the flattened map.
        dataset_idx, sample_idx = self.flat_sample_map[index]
        return self.datasets[dataset_idx][sample_idx]
    def _compute_len(self, is_eval: bool = False):
        """
        Pre-computes and stores `all_sample_indices`, a list of episode indices sampled
        from each constituent dataset. This method prepares the data for `_create_flat_sample_map`.

        Args:
            is_eval (bool): Flag indicating if indices are being computed for an evaluation dataset.
        """
        self.all_sample_indices: List[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = []
        for i, (ds, meta) in enumerate(zip(self.datasets, self.meta)):
            # Access the underlying LeRobotDataset or MultiLeRobotDataset, bypassing TransformedDataset wrapper.
            actual_ds = ds._dataset if isinstance(ds, TransformedDataset) else ds
            # Determine the number of indices to sample for this dataset based on the current stage.
            # NOTE(review): `num_indices` is always None here -- presumably `_sample_indices`
            # treats None as "use all episodes"; confirm against its implementation.
            num_indices = None
            if isinstance(actual_ds, MultiLeRobotDataset):
                # For MultiLeRobotDataset, iterate through its sub-datasets to get indices.
                indices_list_for_multi_ds = []
                for sub_ds in actual_ds._datasets:
                    _from = sub_ds.episode_data_index["from"]
                    _to = sub_ds.episode_data_index["to"]
                    indices = self._sample_indices(
                        _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
                    )
                    indices_list_for_multi_ds.append(indices)
                self.all_sample_indices.append(indices_list_for_multi_ds)
            elif isinstance(actual_ds, LeRobotDataset):
                # For a single LeRobotDataset.
                _from = actual_ds.episode_data_index["from"]
                _to = actual_ds.episode_data_index["to"]
                indices = self._sample_indices(
                    _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
                )
                self.all_sample_indices.append(indices)
            else:
                raise TypeError(f"Unsupported dataset type: {type(actual_ds)}. "
                                "Expected `LeRobotDataset` or `MultiLeRobotDataset`.")
        # After collecting all sampled episode indices, flatten them into `flat_sample_map`.
        self.flat_sample_map = self._create_flat_sample_map()
def _create_flat_sample_map(self) -> List[Tuple[int, int]]:
    """Flatten `self.all_sample_indices` into (dataset_idx, frame_idx) pairs.

    Entries come in three shapes depending on dataset type: a list of lists of
    tensors (MultiLeRobotDataset), a list of tensors (LeRobotDataset), or a
    bare tensor. All three are unrolled elementwise, preserving order, so that
    `__getitem__` can do a single O(1) lookup.
    """
    pairs: List[Tuple[int, int]] = []

    def _unroll(ds_idx: int, tensor: torch.Tensor) -> None:
        # One tensor of frame indices -> a run of (dataset, frame) pairs.
        pairs.extend((ds_idx, int(v)) for v in tensor.tolist())

    for ds_idx, group in enumerate(self.all_sample_indices):
        if isinstance(group, torch.Tensor):
            # Rare case: a single tensor of indices for the whole dataset.
            _unroll(ds_idx, group)
        elif isinstance(group, list) and group and isinstance(group[0], list):
            # MultiLeRobotDataset: one index list per sub-dataset.
            for sub_group in group:
                for tensor in sub_group:
                    _unroll(ds_idx, tensor)
        elif isinstance(group, list) and group and isinstance(group[0], torch.Tensor):
            # LeRobotDataset: one tensor per episode.
            for tensor in group:
                _unroll(ds_idx, tensor)
    return pairs
def _sample_indices(
    self,
    start: List[int],
    end: List[int],
    num_frames: Optional[int],
    random_pad: bool = False,
    is_eval: bool = False,
    dataset_name: Optional[str] = None,  # Kept for potential future stage-specific logic.
) -> List[torch.Tensor]:
    """
    Samples frame indices per episode.

    Args:
        start (List[int]): Starting frame index of each episode.
        end (List[int]): Ending (exclusive) frame index of each episode.
        num_frames (Optional[int]): Target number of frames to sample per episode.
            None means "keep every frame of the episode".
        random_pad (bool): If True and an episode is shorter than the target,
            pad it with randomly duplicated indices from itself.
        is_eval (bool): Evaluation flag. NOTE(review): currently unused in the
            visible code — confirm whether eval-time index shifting was intended.
        dataset_name (Optional[str]): Dataset name, for debugging/future rules.

    Returns:
        List[torch.Tensor]: One tensor of sampled frame indices per episode.
    """
    all_indices_for_episodes = []
    for _start, _end in zip(start, end):
        frame_count = _end - _start  # Total frames available in this episode.
        # Fix: honor `num_frames` when provided. Previously `target_frames` was
        # unconditionally set to `frame_count`, leaving `num_frames` unused and
        # making the padding branch unreachable dead code. Callers that pass
        # None get the exact original behavior (all frames kept).
        target_frames = num_frames if num_frames is not None else frame_count
        if frame_count >= target_frames:
            # Enough frames: linearly space `target_frames` indices over the episode.
            indices = torch.linspace(_start, _end - 1, steps=target_frames).long()
        elif random_pad:
            # Short episode: pad with randomly chosen duplicates, then shuffle
            # so originals and duplicates are interleaved.
            pad_size = target_frames - frame_count
            indices = torch.arange(_start, _end)
            pad_indices = indices[torch.randint(0, frame_count, (pad_size,))]
            indices = torch.cat([indices, pad_indices])
            indices = indices[torch.randperm(target_frames)]
        else:
            # Short episode, no padding: just use all available frames.
            indices = torch.arange(_start, _end)
        all_indices_for_episodes.append(indices)
    return all_indices_for_episodes
def _get_weights(self, datasets_weights: Dict[str, float], aug_ratio: float = 1.0):
    """
    Assigns normalized sampling weights to each individual sample in the flattened map.

    A sample's weight is its dataset's base weight, multiplied by `aug_ratio` when
    the sample's frame index appears in the dataset's `gripper_change_step_idx`.
    All weights are normalized to sum to 1 at the end.

    NOTE(review): the per-sample weights appended here are assumed to align
    one-to-one, in order, with the entries of `flat_sample_map` — verify against
    `_create_flat_sample_map`'s flattening order.

    Args:
        datasets_weights (Dict[str, float]): Dataset name -> base weight; missing
            names default to 1.0. If None, uniform 1/num_datasets weights are used.
        aug_ratio (float): Multiplier applied to the base weight for samples with
            a detected gripper change.
    """
    self.sample_weights: List[float] = []
    self.datasets_weight_map: Dict[str, float] = {}
    if datasets_weights is None:
        num_datasets = len(self.datasets_name)
        datasets_weights = {name: 1.0 / num_datasets for name in self.datasets_name}
    for idx, ds_name in enumerate(self.datasets_name):
        # Unwrap TransformedDataset to reach gripper change information.
        current_base_dataset = self.datasets[idx]._dataset if isinstance(self.datasets[idx], TransformedDataset) else self.datasets[idx]
        base_weight = datasets_weights.get(ds_name, 1.0)  # Get base weight for this dataset
        individual_weights_for_ds: List[float] = []
        if isinstance(current_base_dataset, MultiLeRobotDataset):
            # For MultiLeRobotDataset, iterate through its sub-datasets.
            # NOTE(review): if a sub-dataset lacks `gripper_change_step_idx`, no
            # weights are appended for its samples here, and the fallback below
            # only looks at the LAST sub-dataset's value — this can misalign
            # weights with `flat_sample_map`. Confirm and fix upstream.
            for idj, sub_ds in enumerate(current_base_dataset._datasets):
                gripper_change_step_idx = getattr(sub_ds, 'gripper_change_step_idx', None)
                if gripper_change_step_idx is not None:
                    sampled_indices_sub_ds = self.all_sample_indices[idx][idj]
                    for tensor_of_indices in sampled_indices_sub_ds:
                        for step_idx in tensor_of_indices.tolist():
                            if step_idx in gripper_change_step_idx:
                                individual_weights_for_ds.append(base_weight * aug_ratio)
                            else:
                                individual_weights_for_ds.append(base_weight)
        elif isinstance(current_base_dataset, LeRobotDataset):
            # For a single LeRobotDataset.
            gripper_change_step_idx = getattr(current_base_dataset, 'gripper_change_step_idx', None)
            if gripper_change_step_idx is not None:
                sampled_indices_ds = self.all_sample_indices[idx]
                for tensor_of_indices in sampled_indices_ds:
                    for step_idx in tensor_of_indices.tolist():
                        if step_idx in gripper_change_step_idx:
                            individual_weights_for_ds.append(base_weight * aug_ratio)
                        else:
                            individual_weights_for_ds.append(base_weight)
        # NOTE(review): if `current_base_dataset` is neither type above,
        # `gripper_change_step_idx` is unbound and this raises NameError.
        if gripper_change_step_idx is None:
            print(f"Warning: Gripper change detection not fully supported for dataset type {type(current_base_dataset)}. "
                  "Assigning uniform weights based on `base_weight` for this dataset.")
            num_samples_for_ds_in_flat_map = sum(1 for map_ds_idx, _ in self.flat_sample_map if map_ds_idx == idx)
            individual_weights_for_ds.extend([base_weight] * num_samples_for_ds_in_flat_map)
        # Accumulate individual weights for all samples and for the dataset's total.
        self.sample_weights.extend(individual_weights_for_ds)
        self.datasets_weight_map[ds_name] = self.datasets_weight_map.get(ds_name, 0.0) + sum(individual_weights_for_ds)
    # Final normalization of all individual sample weights across the entire mixture dataset.
    total_sum_of_all_individual_weights = sum(self.sample_weights)
    if total_sum_of_all_individual_weights > 0:
        self.sample_weights = np.array(self.sample_weights, dtype=np.float32)
        self.sample_weights = self.sample_weights / total_sum_of_all_individual_weights
    else:
        self.sample_weights = np.array([], dtype=np.float32)
    # Normalize the `datasets_weight_map` to reflect the effective proportion of each dataset
    # in the final sampling distribution.
    if total_sum_of_all_individual_weights > 0:
        for k in self.datasets_weight_map:
            self.datasets_weight_map[k] /= total_sum_of_all_individual_weights
    else:
        self.datasets_weight_map = {k: 0.0 for k in self.datasets_weight_map}  # All weights become zero.
def __str__(self) -> str:
    """
    Returns a formatted, ANSI-colored table of the mixture: per-dataset sample
    counts and effective (normalized) sampling weights, plus a total episode count.
    """
    # ANSI escape codes for colored and bold terminal output.
    RESET = "\033[0m"
    BOLD = "\033[1m"
    CYAN = "\033[96m"
    YELLOW = "\033[93m"
    GREEN = "\033[92m"
    MAGENTA = "\033[95m"
    # Determine the maximum key length for consistent column alignment.
    max_key_len = max(len(k) for k in self.datasets_weight_map.keys()) + 2 if self.datasets_weight_map else 20
    lines = [
        f"{BOLD}{MAGENTA}######################################### 👈 Dataset Weight Map: ########################################{RESET}"
    ]
    # One row per dataset: name, sample count, effective weight percentage.
    # Relies on datasets_weight_map preserving insertion order so `idx` lines
    # up with self.datasets.
    for idx, (name, weight) in enumerate(self.datasets_weight_map.items()):
        lines.append(f"{CYAN}{name:<{max_key_len}} : {len(self.datasets[idx]):>18.0f} ({weight*100:>.2f}%){RESET}")
    # Separator sized to the visible (non-ANSI) width of the header.
    separator_length = len(lines[0]) - len(BOLD) - len(MAGENTA) - len(RESET) + 1
    lines.append("-" * separator_length)
    # NOTE(review): `self.num_episodes` is not assigned anywhere in the visible
    # code — presumably set in __init__; confirm it exists before relying on __str__.
    lines.append(f"{CYAN}{'Total Episodes':<{max_key_len}}{RESET} : {YELLOW}{sum(self.num_episodes):>18.0f}{RESET}")
    lines.append(f"{BOLD}{MAGENTA}{'#' * separator_length}{RESET}")
    return "\n".join(lines)
def create_mixture_dataset(
    data_configs_list,
    action_horizon,
    model_config,
):
    """Build a transformed MixtureDataset from a nested list of dataset configs.

    Each config selects a LeRobot repo at `<repo_dir>/<task_id>/<subtask_id>`,
    optionally keeps a random `data_ratio` fraction of episodes, attaches
    delta timestamps for `action_horizon` action steps at the (downsampled)
    dataset FPS, runs gripper-change detection when configured, and applies
    the config's data transforms.

    Args:
        data_configs_list: Iterable of iterables of dataset configs.
        action_horizon: Number of action steps per sample.
        model_config: Unused here; kept for signature parity with the sibling
            factory functions.

    Returns:
        A MixtureDataset over all configured datasets (gripper_aug_ratio=10.0).
    """
    all_datasets = []
    all_datasets_name = []
    all_datasets_meta = []
    all_datasets_weight = {}
    for ds_configs in data_configs_list:
        for ds_config in ds_configs:
            root_path = f"{ds_config.repo_dir}/{ds_config.task_id}/{ds_config.subtask_id}"
            dataset_meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(dataset_meta.episodes_stats.keys())
            if ds_config.data_ratio < 1.0:
                # Randomly keep a data_ratio fraction of episodes (at least one).
                sub_length = int(len(episodes) * ds_config.data_ratio) + 1
                logging.info(f"sub_length: {sub_length}")
                indices = np.random.choice(len(episodes), sub_length, replace=False)
                episodes = [episodes[i] for i in indices]
            # Fix: use the module logger instead of a stray print(), consistent
            # with the logging.info call above.
            logging.info(f"downsample ratio: {ds_config.downsample_ratio}")
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                delta_timestamps={
                    key: [t / (dataset_meta.fps // ds_config.downsample_ratio) for t in range(action_horizon)]
                    for key in ds_config.action_sequence_keys
                },
            )
            if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
                gripper_aug_config = ds_config.gripper_aug_config
                # Flag frames where the gripper flips so MixtureDataset can upweight them.
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=gripper_aug_config["gripper_action_keys"],
                    gripper_dim=gripper_aug_config["gripper_dim"],
                    threshold_method=gripper_aug_config["gripper_threshold_method"],
                    threshold_multiplier=gripper_aug_config["gripper_threshold_multiplier"],
                    min_threshold=gripper_aug_config["gripper_min_threshold"],
                    max_threshold=gripper_aug_config["gripper_max_threshold"],
                )
            dataset = transform_dataset(dataset, ds_config)
            all_datasets.append(dataset)
            all_datasets_name.append(root_path)
            all_datasets_meta.append(dataset_meta)
            all_datasets_weight[root_path] = ds_config.weight
    return MixtureDataset(
        all_datasets,
        all_datasets_name,
        all_datasets_meta,
        all_datasets_weight,
        gripper_aug_ratio=10.0,
    )
def create_mixture_dataset_no_transform(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build a MixtureDataset like `create_mixture_dataset`, but WITHOUT applying
    the per-config data transforms.

    Differences from the transformed variant: no `transform_dataset` call, and
    `data_ratio` subsampling takes the first episodes deterministically instead
    of a random choice.

    Args:
        data_configs_list: Iterable of iterables of dataset configs.
        action_horizon: Number of action steps per sample.
        model_config: Unused here; kept for signature parity. (TODO confirm.)

    Returns:
        A MixtureDataset over the raw (untransformed) datasets.
    """
    all_datasets = []
    all_datasets_name = []
    all_datasets_meta = []
    all_datasets_weight = {}
    for ds_configs in data_configs_list:
        for ds_config in ds_configs:
            repo_dir = ds_config.repo_dir
            task_id = ds_config.task_id
            subtask_id = ds_config.subtask_id
            root_path = f"{repo_dir}/{task_id}/{subtask_id}"
            dataset_meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(dataset_meta.episodes_stats.keys())
            if ds_config.data_ratio < 1.0:
                # Deterministic prefix subsample (at least one episode).
                sub_length = int(len(episodes) * ds_config.data_ratio) + 1
                episodes = episodes[:sub_length]
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                delta_timestamps={
                    key: [t / (dataset_meta.fps // ds_config.downsample_ratio) for t in range(action_horizon)] for key in ds_config.action_sequence_keys
                },
            )
            if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
                gripper_aug_config = ds_config.gripper_aug_config
                # Flag gripper-change frames so MixtureDataset can upweight them.
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=gripper_aug_config["gripper_action_keys"],
                    gripper_dim=gripper_aug_config["gripper_dim"],
                    threshold_method=gripper_aug_config["gripper_threshold_method"],
                    threshold_multiplier=gripper_aug_config["gripper_threshold_multiplier"],
                    min_threshold=gripper_aug_config["gripper_min_threshold"],
                    max_threshold=gripper_aug_config["gripper_max_threshold"],
                )
            dataset_name = root_path
            dataset_weight = ds_config.weight
            all_datasets.append(dataset)
            all_datasets_name.append(dataset_name)
            all_datasets_meta.append(dataset_meta)
            all_datasets_weight[dataset_name] = dataset_weight
    mixture_dataset = MixtureDataset(
        all_datasets,
        all_datasets_name,
        all_datasets_meta,
        all_datasets_weight,
        gripper_aug_ratio=10.0,
    )
    return mixture_dataset
def create_mixture_dataset_calculate_norm_stats(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build a MixtureDataset intended for normalization-statistics computation.

    Like `create_mixture_dataset_no_transform`, but with video loading disabled
    (`load_video=False`) since norm stats only need the numeric streams.

    NOTE(review): unlike the other factories, this iterates `data_configs_list`
    as a FLAT list of configs (single loop level) — confirm callers pass the
    matching structure.

    Args:
        data_configs_list: Iterable of dataset configs (flat).
        action_horizon: Number of action steps per sample.
        model_config: Unused here; kept for signature parity. (TODO confirm.)

    Returns:
        A MixtureDataset suitable for computing normalization statistics.
    """
    all_datasets = []
    all_datasets_name = []
    all_datasets_meta = []
    all_datasets_weight = {}
    for ds_config in data_configs_list:
        repo_dir = ds_config.repo_dir
        task_id = ds_config.task_id
        subtask_id = ds_config.subtask_id
        root_path = f"{repo_dir}/{task_id}/{subtask_id}"
        dataset_meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
        episodes = list(dataset_meta.episodes_stats.keys())
        if ds_config.data_ratio < 1.0:
            # Deterministic prefix subsample (at least one episode).
            sub_length = int(len(episodes) * ds_config.data_ratio) + 1
            episodes = episodes[:sub_length]
        dataset = LeRobotDataset(
            episodes=episodes,
            repo_id=root_path,
            root=root_path,
            delta_timestamps={
                key: [t / (dataset_meta.fps // ds_config.downsample_ratio) for t in range(action_horizon)] for key in ds_config.action_sequence_keys
            },
            load_video=False,
        )
        if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
            gripper_aug_config = ds_config.gripper_aug_config
            # Flag gripper-change frames so MixtureDataset can upweight them.
            dataset = detect_gripper_change_step(
                dataset,
                select_actions=gripper_aug_config["gripper_action_keys"],
                gripper_dim=gripper_aug_config["gripper_dim"],
                threshold_method=gripper_aug_config["gripper_threshold_method"],
                threshold_multiplier=gripper_aug_config["gripper_threshold_multiplier"],
                min_threshold=gripper_aug_config["gripper_min_threshold"],
                max_threshold=gripper_aug_config["gripper_max_threshold"],
            )
        dataset_name = root_path
        dataset_weight = ds_config.weight
        all_datasets.append(dataset)
        all_datasets_name.append(dataset_name)
        all_datasets_meta.append(dataset_meta)
        all_datasets_weight[dataset_name] = dataset_weight
    mixture_dataset = MixtureDataset(
        all_datasets,
        all_datasets_name,
        all_datasets_meta,
        all_datasets_weight,
        gripper_aug_ratio=10.0,
    )
    return mixture_dataset

View File

@@ -0,0 +1,123 @@
import dataclasses
from typing import Protocol, runtime_checkable
import jax.numpy as jnp
import optax
import openpi.shared.array_typing as at
@runtime_checkable
class LRScheduleConfig(Protocol):
    """Structural interface for learning-rate schedule configs: `create()` builds the optax schedule."""

    def create(self) -> optax.Schedule: ...
@dataclasses.dataclass(frozen=True)
class CosineDecaySchedule(LRScheduleConfig):
    """Cosine decay schedule with warmup."""

    warmup_steps: int = 1_000
    peak_lr: float = 2.5e-5
    decay_steps: int = 30_000
    decay_lr: float = 2.5e-6

    def create(self) -> optax.Schedule:
        # Start near zero, ramp linearly to peak_lr over warmup_steps, then
        # cosine-decay down to decay_lr over decay_steps.
        initial_lr = self.peak_lr / (self.warmup_steps + 1)
        return optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            end_value=self.decay_lr,
        )
@dataclasses.dataclass(frozen=True)
class RsqrtDecaySchedule(LRScheduleConfig):
    """Inverse square root decay schedule with warmup."""

    warmup_steps: int = 1_000
    peak_lr: float = 5e-5
    timescale: float = 10_000

    def create(self) -> optax.Schedule:
        # Phase 1: linear warmup from ~0 to peak_lr.
        warmup = optax.linear_schedule(
            init_value=self.peak_lr / (self.warmup_steps + 1),
            end_value=self.peak_lr,
            transition_steps=self.warmup_steps,
        )

        def rsqrt_decay(step):
            # peak_lr * sqrt(timescale / (timescale + step))
            return self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale)

        # Phase 2 takes over after warmup_steps.
        return optax.join_schedules([warmup, rsqrt_decay], [self.warmup_steps])
@dataclasses.dataclass(frozen=True)
class WarmupConstantSchedule(LRScheduleConfig):
    """Linear warmup to `peak_lr`, then constant forever after."""

    warmup_steps: int = 2_000
    peak_lr: float = 5e-5

    def create(self) -> optax.Schedule:
        initial_lr = self.peak_lr / (self.warmup_steps + 1)
        return optax.warmup_constant_schedule(
            init_value=initial_lr,
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
        )
@runtime_checkable
class OptimizerConfig(Protocol):
    """Structural interface for optimizer configs: `create()` builds the optax transformation."""

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation: ...
@dataclasses.dataclass(frozen=True)
class AdamW(OptimizerConfig):
    """AdamW optimizer with global gradient-norm clipping."""

    b1: float = 0.9
    b2: float = 0.95
    eps: float = 1e-8
    # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value.
    weight_decay: float = 1e-10
    clip_gradient_norm: float = 1.0

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        # Clip first, then apply AdamW.
        clip = optax.clip_by_global_norm(self.clip_gradient_norm)
        adamw = optax.adamw(
            lr,
            b1=self.b1,
            b2=self.b2,
            eps=self.eps,
            weight_decay=self.weight_decay,
            mask=weight_decay_mask,
        )
        return optax.chain(clip, adamw)
@dataclasses.dataclass(frozen=True)
class SGD(OptimizerConfig):
    """SGD optimizer. Weight-decay masking is not supported."""

    lr: float = 5e-5
    momentum: float = 0.9
    nesterov: bool = False

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        # Fix: raise instead of `assert` so the validation survives `python -O`
        # (asserts are stripped under optimization flags).
        if weight_decay_mask is not None:
            raise ValueError("Weight decay is not supported for SGD")
        return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)
def create_optimizer(
    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
) -> optax.GradientTransformation:
    """Build the optax gradient transformation from an optimizer config and an LR-schedule config."""
    schedule = lr_schedule.create()
    return optimizer.create(schedule, weight_decay_mask=weight_decay_mask)

View File

@@ -0,0 +1,102 @@
import contextlib
import logging
import jax
import numpy as np
BATCH_AXIS = "batch"  # mesh axis used for pure data parallelism
FSDP_AXIS = "fsdp"  # mesh axis along which parameters are sharded
# In FSDP, we shard the data across both the batch and FSDP axes.
DATA_AXIS = (BATCH_AXIS, FSDP_AXIS)
class _MeshState:
    # Module-global holder for the mesh installed by `set_mesh`; read by
    # `activation_sharding_constraint`. None means "no mesh active".
    active_mesh: jax.sharding.Mesh | None = None
def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
    """Create a 2-D (batch, fsdp) device mesh covering all available devices.

    Raises:
        ValueError: If the total device count is not a multiple of `num_fsdp_devices`.
    """
    total_devices = jax.device_count()
    if total_devices % num_fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {total_devices} must be divisible by the number of FSDP devices {num_fsdp_devices}."
        )
    shape = (total_devices // num_fsdp_devices, num_fsdp_devices)
    return jax.make_mesh(shape, (BATCH_AXIS, FSDP_AXIS))
@contextlib.contextmanager
def set_mesh(mesh: jax.sharding.Mesh):
    """Plumbing the mesh deep into the module tree is extremely cumbersome; until the JAX team lands a better API, a
    custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used
    in `activation_sharding_constraint` below.

    Raises:
        ValueError: If a mesh is already active (nesting is not allowed).
    """
    if _MeshState.active_mesh is not None:
        raise ValueError("Cannot nest set_mesh context managers.")
    _MeshState.active_mesh = mesh
    try:
        yield
    finally:
        # Always clear the global reference, even if the body raised.
        _MeshState.active_mesh = None
def activation_sharding_constraint(pytree):
    """Constrain activations to be sharded along the data axes when a global mesh is active.

    A no-op (returns `pytree` unchanged) outside of a `set_mesh` context.
    """
    mesh = _MeshState.active_mesh
    if mesh is None:
        return pytree
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(DATA_AXIS))
    return jax.lax.with_sharding_constraint(pytree, sharding)
def fsdp_sharding(
    pytree,
    mesh: jax.sharding.Mesh,
    *,
    min_size_mbytes: int = 4,  # 4 MiB
    log: bool = False,
):
    """Apply FSDP sharding to a pytree of arrays based on the mesh shape.

    Args:
        pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. contains .shape attr)
            will be considered for sharding.
        mesh: The mesh being used for applying sharding on to pytree.
        min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this
            will be replicated.
        log: If true, will log the sharding decisions for arrays that are being considered for sharding.

    Returns:
        The sharded pytree.
    """
    min_size_bytes = min_size_mbytes * 2**20

    def _shard_arr(kp, array: jax.ShapeDtypeStruct):
        # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging
        if mesh.shape[FSDP_AXIS] == 1:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate scalar and vector arrays
        if not hasattr(array, "shape"):
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        if len(array.shape) < 2:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate small arrays
        if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension
        # (axes sorted by size, descending; the first divisible one wins and we return immediately).
        axes = np.argsort(array.shape)[::-1]
        spec = [None] * len(axes)
        for i in axes:
            if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:
                if log:
                    logging.info(
                        f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}"
                    )
                spec[i] = FSDP_AXIS
                return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))
        # replicate if no valid sharding was found
        if log:
            logging.warning(
                f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}"
            )
        return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    return jax.tree_util.tree_map_with_path(_shard_arr, pytree)

View File

@@ -0,0 +1,38 @@
from collections.abc import Callable
from typing import Any
from flax import nnx
from flax import struct
import jax
import optax
from openpi.models import model as _model
from openpi.shared import array_typing as at
@at.typecheck
@struct.dataclass
class TrainState:
    """Complete training state: step counter, parameters, optimizer state, and optional EMA copy."""

    step: at.Int[at.ArrayLike, ""]  # scalar global training step
    params: nnx.State  # current model parameters
    model_def: nnx.GraphDef[_model.BaseModel]  # static model graph definition
    opt_state: optax.OptState  # optimizer accumulator state
    tx: optax.GradientTransformation = struct.field(pytree_node=False)  # optimizer itself (static, not traced)
    ema_decay: float | None = struct.field(pytree_node=False)  # EMA rate; None disables EMA
    ema_params: nnx.State | None = None  # EMA shadow of `params` when enabled
@at.typecheck
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Render a PyTree as one "path: value" line per leaf for logging.

    `interp_func` maps each leaf to its printed form (defaults to `str`).
    """
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    lines = [f"{jax.tree_util.keystr(path)}: {interp_func(leaf)}" for path, leaf in leaves]
    return "\n".join(lines)
@at.typecheck
def array_tree_to_info(tree: at.PyTree) -> str:
    """Render a PyTree of arrays as "path: shape@dtype" lines for logging."""

    def _describe(arr):
        return f"{arr.shape}@{arr.dtype}"

    return tree_to_info(tree, _describe)

View File

@@ -0,0 +1,103 @@
import dataclasses
import logging
import re
from typing import Protocol, runtime_checkable
import flax.traverse_util
import numpy as np
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.download as download
from pathlib import Path
logger = logging.getLogger(__name__)
@runtime_checkable
class WeightLoader(Protocol):
    """Structural interface for strategies that populate model parameters."""

    def load(self, params: at.Params) -> at.Params:
        """Loads the model weights.

        Args:
            params: Parameters of the model. This is a nested structure of array-like objects that
                represent the model's parameters.

        Returns:
            Loaded parameters. The structure must be identical to `params`. If returning a subset of
            the parameters the loader must merge the loaded parameters with `params`.
        """
@dataclasses.dataclass(frozen=True)
class NoOpWeightLoader(WeightLoader):
    """Weight loader that leaves the reference parameters untouched."""

    def load(self, params: at.Params) -> at.Params:
        # Nothing to load; keep the (e.g. freshly initialized) params as-is.
        return params
@dataclasses.dataclass(frozen=True)
class CheckpointWeightLoader(WeightLoader):
    """Loads an entire set of weights from a checkpoint.

    Compatible with:
      trained checkpoints:
        example: "./checkpoints/<config>/<exp>/<step>/params"
      released checkpoints:
        example: "gs://openpi-assets/checkpoints/<model>/params"
    """

    # Local path or gs:// URI of the params directory; remote paths are downloaded first.
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        # We are loading np.ndarray and relying on the training code to properly convert and shard the params.
        loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray)
        # Add all missing LoRA weights (checkpoints trained without LoRA lack them,
        # so they are filled from the freshly initialized reference params).
        return _merge_params(loaded_params, params, missing_regex=".*lora.*")
@dataclasses.dataclass(frozen=True)
class PaliGemmaWeightLoader(WeightLoader):
    """Loads weights from the official PaliGemma checkpoint.

    This will overwrite existing weights with similar names while keeping all extra weights intact.
    This allows us to support the action expert which is used by the Pi0 model.
    """

    # Path to the official .npz checkpoint file.
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        path = Path(self.params_path)
        with path.open("rb") as f:
            # allow_pickle=False: refuse pickled objects in the archive (safer).
            flat_params = dict(np.load(f, allow_pickle=False))
        # The npz stores flat "a/b/c" keys; rebuild the nested dict under "PaliGemma".
        loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]}
        # Add all missing weights (".*" regex: everything not in the checkpoint is
        # filled from the reference params, e.g. the action expert).
        return _merge_params(loaded_params, params, missing_regex=".*")
def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
    """Merges the loaded parameters with the reference parameters.

    Every loaded leaf whose flattened key exists in `params` is kept (cast to the
    reference dtype when it differs). Reference keys that match `missing_regex`
    and were not loaded are backfilled from `params`.

    Args:
        loaded_params: The parameters to merge.
        params: The reference parameters.
        missing_regex: A regex pattern (full match) selecting reference keys to backfill.

    Returns:
        A new nested dictionary with the merged parameters.
    """
    flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
    flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")

    # Keep only loaded leaves that exist in the reference structure.
    merged = {}
    for key, value in flat_loaded.items():
        if key in flat_ref:
            ref_dtype = flat_ref[key].dtype
            merged[key] = value.astype(ref_dtype) if value.dtype != ref_dtype else value
    # Drop references to loaded arrays we did not keep, freeing them early.
    flat_loaded.clear()

    # Backfill reference leaves selected by the regex that were not loaded.
    pattern = re.compile(missing_regex)
    for key in flat_ref:
        if key not in merged and pattern.fullmatch(key):
            merged[key] = flat_ref[key]
    return flax.traverse_util.unflatten_dict(merged, sep="/")

View File

@@ -0,0 +1,597 @@
from collections.abc import Callable, Mapping, Sequence
import dataclasses
import re
from typing import Protocol, TypeAlias, TypeVar, runtime_checkable
import flax.traverse_util as traverse_util
import jax
import numpy as np
from openpi_client import image_tools
from openpi.models import tokenizer as _tokenizer
from openpi.shared import array_typing as at
from openpi.shared import normalize as _normalize
from scipy.spatial.transform import Rotation as R
from pdb import set_trace
DataDict: TypeAlias = at.PyTree
NormStats: TypeAlias = _normalize.NormStats
T = TypeVar("T")
S = TypeVar("S")
@runtime_checkable
class DataTransformFn(Protocol):
    """Structural interface for a single data-pipeline transform step."""

    def __call__(self, data: DataDict) -> DataDict:
        """Apply transformation to the data.

        Args:
            data: The data to apply the transform to. This is a possibly nested dictionary that contains
                unbatched data elements. Each leaf is expected to be a numpy array. Using JAX arrays is allowed
                but not recommended since it may result in extra GPU memory usage inside data loader worker
                processes.

        Returns:
            The transformed data. Could be the input `data` that was modified in place, or a new data structure.
        """
@dataclasses.dataclass(frozen=True)
class Group:
    """A group of transforms."""

    # Transforms applied to the model input data.
    inputs: Sequence[DataTransformFn] = ()
    # Transforms applied to the model output data.
    outputs: Sequence[DataTransformFn] = ()

    def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Sequence[DataTransformFn] = ()) -> "Group":
        """Append transforms to the group and return a new group.

        `inputs` go after the existing input transforms; `outputs` go before the
        existing output transforms (so the pipeline stays symmetric).
        """
        new_inputs = (*self.inputs, *inputs)
        new_outputs = (*outputs, *self.outputs)
        return Group(inputs=new_inputs, outputs=new_outputs)
@dataclasses.dataclass(frozen=True)
class CompositeTransform(DataTransformFn):
    """Applies a sequence of transforms in order, feeding each one's output into the next."""

    transforms: Sequence[DataTransformFn]

    def __call__(self, data: DataDict) -> DataDict:
        result = data
        for step in self.transforms:
            result = step(result)
        return result
def compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn:
    """Bundle `transforms` into a single transform that applies them left-to-right."""
    return CompositeTransform(transforms=transforms)
@dataclasses.dataclass(frozen=True)
class RepackTransform(DataTransformFn):
    """Repacks an input dictionary into a new dictionary.

    Repacking is defined using a dictionary where the keys are the new keys and the values
    are the flattened paths to the old keys. We use '/' as the separator during flattening.

    Example:
    {
        "images": {
            "cam_high": "observation.images.top",
            "cam_low": "observation.images.bottom",
        },
        "state": "observation.state",
        "actions": "action",
    }
    """

    # New output structure; each string leaf is a '/'-flattened path into the input.
    structure: at.PyTree[str]

    def __call__(self, data: DataDict) -> DataDict:
        # Flatten the input, then rebuild it in the requested shape by mapping
        # each path string in `structure` to the corresponding flattened value.
        flat_item = flatten_dict(data)
        return jax.tree.map(lambda k: flat_item[k], self.structure)
@dataclasses.dataclass(frozen=True)
class ReTransform(DataTransformFn):
    """Repacks an input dictionary into a new dictionary (same contract as `RepackTransform`).

    The output structure's string leaves are '/'-flattened paths into the input.

    Fix: removed a leftover `import pdb; pdb.set_trace()` debugger breakpoint that
    froze any pipeline using this transform; also replaced the copy-pasted docstring.
    """

    # New output structure; each string leaf is a '/'-flattened path into the input.
    structure: at.PyTree[str]

    def __call__(self, data: DataDict) -> DataDict:
        flat_item = flatten_dict(data)
        return jax.tree.map(lambda k: flat_item[k], self.structure)
@dataclasses.dataclass(frozen=True)
class InjectDefaultPrompt(DataTransformFn):
    """Adds a default `prompt` entry when the sample does not already carry one."""

    # Default prompt text; None disables injection entirely.
    prompt: str | None

    def __call__(self, data: DataDict) -> DataDict:
        # Only inject when a default exists and the sample has no prompt yet.
        if self.prompt is None or "prompt" in data:
            return data
        data["prompt"] = np.asarray(self.prompt)
        return data
@dataclasses.dataclass(frozen=True)
class Normalize(DataTransformFn):
    """Normalizes data leaves using per-key statistics (z-score or quantile)."""

    # Stats keyed like the data; None makes this transform a no-op.
    norm_stats: at.PyTree[NormStats] | None
    # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
    use_quantiles: bool = False
    # If true, will raise an error if any of the keys in the norm stats are not present in the data.
    strict: bool = False

    def __post_init__(self):
        # Quantile mode requires q01/q99 to be present in every stats entry.
        if self.norm_stats is not None and self.use_quantiles:
            _assert_quantile_stats(self.norm_stats)

    def __call__(self, data: DataDict) -> DataDict:
        if self.norm_stats is None:
            return data
        return apply_tree(
            data,
            self.norm_stats,
            self._normalize_quantile if self.use_quantiles else self._normalize,
            strict=self.strict,
        )

    def _normalize(self, x, stats: NormStats):
        # Stats are sliced to x's trailing dim; note the inverse (`Unnormalize`)
        # instead PADS stats to x's dim — asymmetric by design for appended
        # action dims, presumably. TODO confirm.
        mean, std = stats.mean[..., : x.shape[-1]], stats.std[..., : x.shape[-1]]
        return (x - mean) / (std + 1e-6)

    def _normalize_quantile(self, x, stats: NormStats):
        assert stats.q01 is not None
        assert stats.q99 is not None
        # Map [q01, q99] to [-1, 1].
        q01, q99 = stats.q01[..., : x.shape[-1]], stats.q99[..., : x.shape[-1]]
        return (x - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0
@dataclasses.dataclass(frozen=True)
class Unnormalize(DataTransformFn):
    """Inverse of `Normalize`: maps normalized values back to the original scale."""

    norm_stats: at.PyTree[NormStats] | None
    # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
    use_quantiles: bool = False

    def __post_init__(self):
        if self.use_quantiles and self.norm_stats is not None:
            _assert_quantile_stats(self.norm_stats)

    def __call__(self, data: DataDict) -> DataDict:
        if self.norm_stats is None:
            return data
        fn = self._unnormalize_quantile if self.use_quantiles else self._unnormalize
        # strict=True: every key in the norm stats must be present in the data.
        return apply_tree(data, self.norm_stats, fn, strict=True)

    def _unnormalize(self, x, stats: NormStats):
        # Pad stats out to x's width: extra dims get identity scaling (mean 0, std 1).
        mean = pad_to_dim(stats.mean, x.shape[-1], axis=-1, value=0.0)
        std = pad_to_dim(stats.std, x.shape[-1], axis=-1, value=1.0)
        return x * (std + 1e-6) + mean

    def _unnormalize_quantile(self, x, stats: NormStats):
        assert stats.q01 is not None
        assert stats.q99 is not None
        q01, q99 = stats.q01, stats.q99
        dim = q01.shape[-1]
        if dim < x.shape[-1]:
            # Only the first `dim` entries were quantile-normalized; pass the rest through.
            head = (x[..., :dim] + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01
            return np.concatenate([head, x[..., dim:]], axis=-1)
        return (x + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01
@dataclasses.dataclass(frozen=True)
class ResizeImages(DataTransformFn):
    """Resizes (with aspect-preserving padding) every image in data["image"]."""

    height: int
    width: int

    def __call__(self, data: DataDict) -> DataDict:
        resized = {}
        for name, image in data["image"].items():
            resized[name] = image_tools.resize_with_pad(image, self.height, self.width)
        data["image"] = resized
        return data
@dataclasses.dataclass(frozen=True)
class SubsampleActions(DataTransformFn):
    """Keeps every `stride`-th action step along the time axis."""

    stride: int

    def __call__(self, data: DataDict) -> DataDict:
        actions = data["actions"]
        data["actions"] = actions[:: self.stride]
        return data
@dataclasses.dataclass(frozen=True)
class DeltaActions(DataTransformFn):
    """Repacks absolute actions into delta action space."""

    # Boolean mask for the action dimensions to be repacked into delta action space. Length
    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
    # See `make_bool_mask` for more details.
    mask: Sequence[bool] | None

    def __call__(self, data: DataDict) -> DataDict:
        if self.mask is None or "actions" not in data:
            return data
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        state = data["state"]
        actions = data["actions"]
        # Subtract the current state from the masked dims, broadcast over the time axis.
        # NOTE: mutates `actions` in place, matching the original semantics.
        offset = np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)
        actions[..., :dims] -= offset
        data["actions"] = actions
        return data
@dataclasses.dataclass(frozen=True)
class DeltaActionsPose(DataTransformFn):
    """Repacks absolute actions into delta action space.

    Like `DeltaActions`, but the first 6 action dims are treated as a pose
    (xyz position + xyz euler angles) and converted into a pose relative to
    `data["pose"]` via `relative_pose`, rather than plain subtraction.
    Consumes (deletes) the "pose" entry from the data.
    """

    # Boolean mask for the action dimensions to be repacked into delta action space. Length
    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
    # See `make_bool_mask` for more details.
    mask: Sequence[bool] | None

    def __call__(self, data: DataDict) -> DataDict:
        if "actions" not in data or self.mask is None:
            return data
        pose, actions = data["pose"], data["actions"]
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        # `act` and `st` are numpy views; the writes below mutate `actions`/`pose` in place.
        act = actions[..., :dims]
        st = pose[..., :dims]
        # Only the pose portion (first 6 entries) of the mask is honored here.
        pose_mask = mask[:6]
        if np.any(pose_mask):
            pose_action = act[..., :6]
            pose_state = st[..., :6]
            if pose_action.ndim == 2:
                rel_list = []
                # Each action step becomes a pose relative to the single current state pose.
                for i in range(pose_action.shape[0]):
                    rel_list.append(relative_pose(pose_action[i], pose_state))
                rel_pose = np.stack(rel_list, axis=0)
            else:
                raise ValueError("pose_action must be dim 2")
            act[..., :6] = rel_pose
        data["actions"][..., :dims] = act
        # NOTE(review): "pose" is deleted even when no pose dims were masked — confirm intended.
        del data["pose"]
        return data
@dataclasses.dataclass(frozen=True)
class AbsoluteActions(DataTransformFn):
    """Repacks delta actions into absolute action space."""

    # Boolean mask for the action dimensions to be repacked into absolute action space. Length
    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
    # See `make_bool_mask` for more details.
    mask: Sequence[bool] | None

    def __call__(self, data: DataDict) -> DataDict:
        if self.mask is None or "actions" not in data:
            return data
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        state = data["state"]
        actions = data["actions"]
        # Add the current state back onto the masked dims, broadcast over the time axis.
        # NOTE: mutates `actions` in place, matching the original semantics.
        offset = np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)
        actions[..., :dims] += offset
        data["actions"] = actions
        return data
@dataclasses.dataclass(frozen=True)
class AbsoluteActionsPose:
    """Convert relative pose actions back into absolute pose actions.

    Inverse of `DeltaActionsPose`: composes each relative pose action (first 6
    dims, xyz position + xyz euler angles) with the current state pose in
    `data["pose"]` via `absolute_pose`. Unlike `DeltaActionsPose`, the "pose"
    entry is left in the data.
    """

    # Boolean mask for the action dimensions; only the first 6 entries (the pose
    # portion) are honored. If None, this transform is a no-op.
    mask: Sequence[bool] | None

    def __call__(self, data):
        if "actions" not in data or "pose" not in data or self.mask is None:
            return data
        actions = data["actions"]  # (T, D)
        pose = data["pose"]  # (D,)
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        # `act` and `st` are numpy views; the writes below mutate the arrays in place.
        act = actions[..., :dims]
        st = pose[..., :dims]
        pose_mask = mask[:6]
        if np.any(pose_mask):
            pose_action = act[..., :6]
            pose_state = st[..., :6]
            abs_list = []
            # Compose each per-step relative pose with the current state pose.
            for i in range(pose_action.shape[0]):
                abs_list.append(absolute_pose(pose_action[i], pose_state))
            abs_pose = np.stack(abs_list, axis=0)
            act[..., :6] = abs_pose
        data["actions"][..., :dims] = act
        return data
@dataclasses.dataclass(frozen=True)
class TokenizePrompt(DataTransformFn):
    """Tokenizes data["prompt"] (optionally together with the state) for PaliGemma."""

    tokenizer: _tokenizer.PaligemmaTokenizer
    # When true, the state is passed to the tokenizer so it can be encoded as tokens.
    discrete_state_input: bool = False

    def __call__(self, data: DataDict) -> DataDict:
        prompt = data.pop("prompt", None)
        if prompt is None:
            raise ValueError("Prompt is required")

        state = None
        if self.discrete_state_input:
            state = data.get("state", None)
            if state is None:
                raise ValueError("State is required.")

        # Prompts may arrive as 0-d numpy string arrays; unwrap to a Python str.
        if not isinstance(prompt, str):
            prompt = prompt.item()

        tokens, token_masks = self.tokenizer.tokenize(prompt, state)
        return {**data, "tokenized_prompt": tokens, "tokenized_prompt_mask": token_masks}
@dataclasses.dataclass(frozen=True)
class TokenizeFASTInputs(DataTransformFn):
    """Tokenizes prompt, state, and (optionally) actions for FAST-style models."""

    tokenizer: _tokenizer.FASTTokenizer

    def __call__(self, data: DataDict) -> DataDict:
        prompt = data.pop("prompt", None)
        if prompt is None:
            raise ValueError("Prompt is required")
        # Prompts may arrive as 0-d numpy string arrays; unwrap to a Python str.
        if not isinstance(prompt, str):
            prompt = prompt.item()

        state = data["state"]
        actions = data.get("actions")
        tokens, token_mask, ar_mask, loss_mask = self.tokenizer.tokenize(prompt, state, actions)

        out = dict(data)
        out["tokenized_prompt"] = tokens
        out["tokenized_prompt_mask"] = token_mask
        out["token_ar_mask"] = ar_mask
        out["token_loss_mask"] = loss_mask
        return out
@dataclasses.dataclass(frozen=True)
class ExtractFASTActions(DataTransformFn):
    """Decodes FAST model output tokens into continuous actions."""

    tokenizer: _tokenizer.FASTTokenizer
    action_horizon: int
    action_dim: int

    def __call__(self, data: DataDict) -> DataDict:
        if "actions" not in data:
            return data
        # Model outputs are saved in "actions", but for FAST models they represent tokens.
        tokens = data.pop("actions").astype(np.int32)
        decoded = self.tokenizer.extract_actions(tokens, self.action_horizon, self.action_dim)
        return {**data, "actions": decoded}
@dataclasses.dataclass(frozen=True)
class PromptFromLeRobotTask(DataTransformFn):
    """Extracts a prompt from the current LeRobot dataset task."""

    # Contains the LeRobot dataset tasks (dataset.meta.tasks).
    tasks: dict[int, str]

    def __call__(self, data: DataDict) -> DataDict:
        if "task_index" not in data:
            raise ValueError('Cannot extract prompt without "task_index"')
        task_index = int(data["task_index"])
        prompt = self.tasks.get(task_index)
        if prompt is None:
            raise ValueError(f"{task_index=} not found in task mapping: {self.tasks}")
        return {**data, "prompt": prompt}
@dataclasses.dataclass(frozen=True)
class PadStatesAndActions(DataTransformFn):
    """Zero-pads states and actions to the model action dimension."""

    model_action_dim: int

    def __call__(self, data: DataDict) -> DataDict:
        target = self.model_action_dim
        # State is always expected; actions are only present during training.
        data["state"] = pad_to_dim(data["state"], target, axis=-1)
        if "actions" in data:
            data["actions"] = pad_to_dim(data["actions"], target, axis=-1)
        return data
def flatten_dict(tree: at.PyTree) -> dict:
    """Flatten a nested dictionary. Uses '/' as the separator."""
    flat = traverse_util.flatten_dict(tree, sep="/")
    return flat
def unflatten_dict(tree: dict) -> at.PyTree:
    """Unflatten a flattened dictionary. Assumes that '/' was used as a separator."""
    nested = traverse_util.unflatten_dict(tree, sep="/")
    return nested
def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree:
    """Transform the structure of a nested dictionary using a set of patterns.

    The transformation is defined using the `patterns` dictionary. The keys are the
    input keys that should be matched and the values are the new names inside the output
    dictionary. If the value is None, the input key is removed.

    Both keys and values should represent flattened paths using '/' as the separator.
    Keys can be regular expressions and values can include backreferences to the
    matched groups (see `re.sub` for more details). Note that the regular expression
    must match the entire key.

    The order inside the `patterns` dictionary is important. Only the first pattern that
    matches the input key will be used.

    See unit tests for more examples.

    Args:
        patterns: A mapping from old keys to new keys.
        tree: The nested dictionary to transform.

    Returns:
        The transformed nested dictionary.
    """
    flat = flatten_dict(tree)
    compiled = [(re.compile(pat), repl) for pat, repl in patterns.items()]

    output: dict = {}
    for key, value in flat.items():
        new_key: str | None = key  # default: keep the original key if nothing matches
        for pattern, repl in compiled:
            if pattern.fullmatch(key):
                # Only the first matching pattern applies.
                new_key = pattern.sub(repl, key, count=1) if repl is not None else None
                break
        if new_key is None:
            # The pattern maps this key to None -> drop the entry.
            continue
        if new_key in output:
            raise ValueError(f"Key '{new_key}' already exists in output")
        output[new_key] = value

    # Validate the output structure to make sure that it can be unflattened:
    # no leaf path may also be a prefix (node) of another path.
    names = sorted(output)
    for name, next_name in zip(names, names[1:]):
        if next_name.startswith(name + "/"):
            raise ValueError(f"Leaf '{name}' aliases a node of '{next_name}'")
    return unflatten_dict(output)
def apply_tree(
    tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False
) -> at.PyTree[T]:
    """Applies `fn(leaf, selector_leaf)` to every leaf of `tree` whose flattened
    key appears in `selector`; all other leaves pass through unchanged.

    With `strict=True`, every selector key must exist in the tree.
    """
    flat_tree = flatten_dict(tree)
    flat_selector = flatten_dict(selector)

    if strict:
        for k in flat_selector:
            if k not in flat_tree:
                raise ValueError(f"Selector key {k} not found in tree")

    result = {}
    for key, value in flat_tree.items():
        if key in flat_selector:
            result[key] = fn(value, flat_selector[key])
        else:
            result[key] = value
    return unflatten_dict(result)
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Pad an array along `axis` up to `target_dim`, filling with `value`.

    The original docstring claimed zero-padding, but the fill is controlled by
    `value` (e.g. `Unnormalize._unnormalize` pads std with 1.0).

    Args:
        x: The array to pad.
        target_dim: The desired size along `axis`.
        axis: The axis to pad (default: last).
        value: The constant fill value (default: 0.0).

    Returns:
        The padded array, or `x` itself (not a copy) when it is already at least
        `target_dim` wide along `axis`.
    """
    current_dim = x.shape[axis]
    if current_dim >= target_dim:
        return x
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, target_dim - current_dim)
    return np.pad(x, pad_width, constant_values=value)
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Positive dims contribute that many True entries, negative dims contribute
    abs(dim) False entries, and zero contributes nothing.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    return tuple(flag for dim in dims for flag in [dim > 0] * abs(dim))
def _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None:
    """Raises ValueError unless every stats entry carries both q01 and q99 quantiles."""
    for key, stats in flatten_dict(norm_stats).items():
        if stats.q01 is not None and stats.q99 is not None:
            continue
        raise ValueError(
            f"quantile stats must be provided if use_quantile_norm is True. Key {key} is missing q01 or q99."
        )
def pose6d_to_pose(pose6d, degrees=False):
    """Convert a 6-DoF vector (xyz position + xyz euler angles) to a 4x4 homogeneous matrix.

    pose6d: (6,)
    return: (4, 4)
    """
    position, euler = pose6d[:3], pose6d[3:]
    pose = np.eye(4)
    pose[:3, 3] = position
    pose[:3, :3] = R.from_euler("xyz", euler, degrees=degrees).as_matrix()
    return pose
def pose_to_6d(pose, degrees=False):
    """Convert a 4x4 homogeneous matrix back to a 6-DoF vector (xyz position + xyz euler angles).

    pose: (4, 4)
    return: (6,)
    """
    euler = R.from_matrix(pose[:3, :3]).as_euler("xyz", degrees=degrees)
    return np.concatenate([pose[:3, 3], euler], axis=0)
def relative_pose(pose_action, pose_state):
    """Express the 6-DoF `pose_action` in the frame of the 6-DoF `pose_state`."""
    state_mat = pose6d_to_pose(pose_state, degrees=False)
    action_mat = pose6d_to_pose(pose_action, degrees=False)
    rel_mat = np.linalg.inv(state_mat) @ action_mat
    return pose_to_6d(rel_mat, degrees=False)
def absolute_pose(pose_delta, pose_state):
    """Compose a relative 6-DoF pose with the state pose to recover the absolute pose."""
    delta_mat = pose6d_to_pose(pose_delta, degrees=False)
    state_mat = pose6d_to_pose(pose_state, degrees=False)
    return pose_to_6d(state_mat @ delta_mat, degrees=False)

View File

@@ -0,0 +1,121 @@
import numpy as np
import pytest
import openpi.models.tokenizer as _tokenizer
import openpi.transforms as _transforms
def test_repack_transform():
    # Leaves of `structure` are '/'-separated paths into the input dict.
    structure = {
        "a": {"b": "b/c"},
        "d": "e/f",
    }
    transform = _transforms.RepackTransform(structure=structure)
    result = transform({"b": {"c": 1}, "e": {"f": 2}})
    assert result == {"a": {"b": 1}, "d": 2}
def test_delta_actions():
    state = np.array([1, 2, 3])
    actions = np.array([[3, 4, 5], [5, 6, 7]])
    transformed = _transforms.DeltaActions(mask=[False, True])({"state": state, "actions": actions})
    # Only the second dim (mask True) has the state subtracted; the rest are untouched.
    assert np.all(transformed["state"] == np.array([1, 2, 3]))
    assert np.all(transformed["actions"] == np.array([[3, 2, 5], [5, 4, 7]]))
def test_delta_actions_noop():
    item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}
    # No-op when the mask is disabled.
    assert _transforms.DeltaActions(mask=None)(item) is item
    # No-op when there are no actions in the input.
    item.pop("actions")
    assert _transforms.DeltaActions(mask=[True, False])(item) is item
def test_absolute_actions():
    state = np.array([1, 2, 3])
    actions = np.array([[3, 4, 5], [5, 6, 7]])
    transformed = _transforms.AbsoluteActions(mask=[False, True])({"state": state, "actions": actions})
    # Only the second dim (mask True) has the state added back.
    assert np.all(transformed["state"] == np.array([1, 2, 3]))
    assert np.all(transformed["actions"] == np.array([[3, 6, 5], [5, 8, 7]]))
def test_absolute_actions_noop():
    item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}
    # No-op when the mask is disabled.
    assert _transforms.AbsoluteActions(mask=None)(item) is item
    # No-op when there are no actions in the input.
    item.pop("actions")
    assert _transforms.AbsoluteActions(mask=[True, False])(item) is item
def test_make_bool_mask():
    expected = (True, True, False, False, True, True)
    assert _transforms.make_bool_mask(2, -2, 2) == expected
    # Zero-size segments contribute nothing.
    assert _transforms.make_bool_mask(2, 0, 2) == (True,) * 4
def test_tokenize_prompt():
    tokenizer = _tokenizer.PaligemmaTokenizer(max_len=12)
    data = _transforms.TokenizePrompt(tokenizer)({"prompt": "Hello, world!"})
    # The transform's output must match tokenizing the prompt directly.
    expected_tokens, expected_mask = tokenizer.tokenize("Hello, world!")
    assert np.allclose(expected_tokens, data["tokenized_prompt"])
    assert np.allclose(expected_mask, data["tokenized_prompt_mask"])
def test_tokenize_no_prompt():
    transform = _transforms.TokenizePrompt(_tokenizer.PaligemmaTokenizer())
    # A missing "prompt" entry must raise.
    with pytest.raises(ValueError, match="Prompt is required"):
        transform({})
def test_transform_dict():
    # NOTE: the local variable was renamed from `input` (shadowed the builtin) to `tree`.
    # Rename and remove keys.
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a/b": "a/c", "a/c": None}, tree)
    assert output == {"a": {"c": 1}}

    # Raises an error since the renamed key conflicts with an existing key.
    with pytest.raises(ValueError, match="Key 'a/c' already exists in output"):
        _transforms.transform_dict({"a/b": "a/c"}, tree)

    # Full match is required and so nothing will be removed.
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a": None}, tree)
    assert output == tree

    # The regex matches the entire key and so the entire input will be removed.
    tree = {"a": {"b": 1, "c": 2}}
    output = _transforms.transform_dict({"a.+": None}, tree)
    assert output == {}

    # Replace keys using backreferences. All leaves named 'c' are replaced with 'd'.
    tree = {"a": {"b": 1, "c": 1}, "b": {"c": 2}}
    output = _transforms.transform_dict({"(.+)/c": r"\1/d"}, tree)
    assert output == {"a": {"b": 1, "d": 1}, "b": {"d": 2}}
def test_extract_prompt_from_task():
    transform = _transforms.PromptFromLeRobotTask({1: "Hello, world!"})
    assert transform({"task_index": 1})["prompt"] == "Hello, world!"
    # Unknown task indices must raise.
    with pytest.raises(ValueError, match="task_index=2 not found in task mapping"):
        transform({"task_index": 2})