multi-node openpi commit
This commit is contained in:
0
policy/openpi-InternData-A1/src/openpi/__init__.py
Normal file
0
policy/openpi-InternData-A1/src/openpi/__init__.py
Normal file
17
policy/openpi-InternData-A1/src/openpi/conftest.py
Normal file
17
policy/openpi-InternData-A1/src/openpi/conftest.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
|
||||
import pynvml
|
||||
import pytest
|
||||
|
||||
|
||||
def set_jax_cpu_backend_if_no_gpu() -> None:
    """Pin JAX to the CPU backend when no usable NVIDIA GPU is detected.

    Probes the NVIDIA driver by initializing (and immediately shutting down)
    NVML; if either call fails, JAX is pointed at the CPU platform via the
    ``JAX_PLATFORMS`` environment variable.
    """
    gpu_available = True
    try:
        pynvml.nvmlInit()
        pynvml.nvmlShutdown()
    except pynvml.NVMLError:
        # NVML init/shutdown failing means there is no usable GPU driver.
        gpu_available = False
    if not gpu_available:
        os.environ["JAX_PLATFORMS"] = "cpu"
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
    """Pytest hook: select the JAX CPU backend before any test imports JAX."""
    set_jax_cpu_backend_if_no_gpu()
|
||||
459
policy/openpi-InternData-A1/src/openpi/models/gemma.py
Normal file
459
policy/openpi-InternData-A1/src/openpi/models/gemma.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Gemma adaptation for Pi, taken from big_vision.
|
||||
|
||||
We follow this einsum axis naming convention:
|
||||
B: batch
|
||||
T: query length
|
||||
S: k/v length
|
||||
N: num query heads
|
||||
K: num k/v heads
|
||||
G: num query heads per k/v head
|
||||
H: head dim
|
||||
D: d_model ("features")
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.lora as lora
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.training.sharding as sharding
|
||||
|
||||
# Vocabulary size of the PaliGemma tokenizer; sizes the shared embedding table.
PALIGEMMA_VOCAB_SIZE = 257_152
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Config:
    """Gemma transformer hyperparameters for a single expert."""

    # Residual-stream width (d_model).
    width: int
    # Number of transformer blocks.
    depth: int
    # Hidden dimension of the feed-forward (MLP) layers.
    mlp_dim: int
    # Number of query heads.
    num_heads: int
    # Number of key/value heads; queries are grouped onto these in attention.
    num_kv_heads: int
    # Dimension of each attention head.
    head_dim: int
    # Optional LoRA settings keyed by weight group ("attn", "ffn"); empty = no LoRA.
    lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
# Names of the Gemma variants accepted by `get_config`.
Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]
|
||||
|
||||
|
||||
def get_config(variant: Variant) -> Config:
    """Returns config for specified gemma variant."""

    def _lora_configs(rank: int, alpha: float) -> dict[str, lora.LoRAConfig]:
        # Identical LoRA settings are applied to the attention and FFN weight groups.
        return {
            "attn": lora.LoRAConfig(rank=rank, alpha=alpha),
            "ffn": lora.LoRAConfig(rank=rank, alpha=alpha),
        }

    if variant == "dummy":
        # Tiny config used for tests.
        return Config(
            width=64,
            depth=4,
            mlp_dim=128,
            num_heads=8,
            num_kv_heads=1,
            head_dim=16,
        )
    if variant in ("gemma_300m", "gemma_300m_lora"):
        # 311M params
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs=_lora_configs(32, 32.0) if variant == "gemma_300m_lora" else {},
        )
    if variant in ("gemma_2b", "gemma_2b_lora"):
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs=_lora_configs(16, 16.0) if variant == "gemma_2b_lora" else {},
        )
    raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
@at.typecheck
class RMSNorm(nn.Module):
    """RMS normalization, optionally adaptive ("adaRMS") on a conditioning vector.

    With ``cond is None`` this is standard RMSNorm with a learned scale and the
    second return value is None. With a conditioning vector, scale/shift/gate
    are regressed from ``cond`` and the gate is returned so the caller can use
    it in a gated residual connection.
    """

    @nn.compact
    def __call__(self, x, cond):
        dtype = x.dtype  # original dtype, could be half-precision
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32
        if cond is None:
            # regular RMSNorm
            # NOTE(review): `(x.shape[-1])` is an int, not a 1-tuple; the zeros
            # initializer accepts an int shape, so the parameter is still 1-D of size d.
            scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
            normed_inputs = normed_inputs * (
                1 + scale
            )  # scale by learned parameter in float32 (matches Flax implementation)
            return normed_inputs.astype(dtype), None  # return in original dtype

        # adaptive RMSNorm: regress per-example scale/shift/gate from the condition
        # (zero-initialized kernel makes this start as an identity transform).
        modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
        scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
        normed_inputs = normed_inputs * (1 + scale) + shift  # scale and shift in float32
        return normed_inputs.astype(dtype), gate
|
||||
|
||||
|
||||
@at.typecheck
class Embedder(nn.Module):
    """Embedder module.

    A single embedding table used both for input token lookup (`encode`) and,
    transposed, for projecting hidden states back to vocabulary logits (`decode`).
    """

    # Number of rows in the embedding table (token vocabulary size).
    vocab_size: int
    # Embedding dimension.
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        """Look up token embeddings and apply the sqrt(d) input scaling."""
        # Tuple indexing `table[(x,)]` is equivalent to `table[x]`: gather rows by token id.
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        """Project hidden states to vocabulary logits using the tied embedding table."""
        return jnp.dot(x, self.input_embedding_table.T)
|
||||
|
||||
|
||||
@at.typecheck
class Attention(nn.Module):
    """Attention module.

    Multi-expert self-attention: each expert has its own q/k/v and output
    projections, but the per-expert q/k/v are concatenated along the sequence
    axis before attention, so tokens of all experts attend to each other
    jointly. Entries of `xs` that are None are skipped and yield a None output
    slot. Returns the per-expert outputs and the (k, v) pair for KV caching.
    """

    # One config per expert; heads must agree across experts (asserted below).
    configs: Sequence[Config]

    @nn.compact
    def __call__(self, xs, positions, attn_mask, kv_cache):
        # all experts must share the same head dim, num heads, and num kv heads for self-attention to work
        assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
        assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
        assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)

        dtype = next(x.dtype for x in xs if x is not None)  # original dtype, could be half-precision

        # Project each active expert's tokens to (q, k, v) with its own weights.
        qkvs = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is None:
                continue
            if config.num_kv_heads == config.num_heads:
                # Full multi-head attention: one fused qkv projection.
                qkv_einsum = lora.Einsum(
                    shape=(3, config.num_heads, config.width, config.head_dim),
                    name=_name("qkv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
            else:
                # Fewer kv heads than query heads: separate q and fused kv projections.
                q_einsum = lora.Einsum(
                    shape=(config.num_heads, config.width, config.head_dim),
                    name=_name("q_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                    lora_config=config.lora_configs.get("attn"),
                )
                q = q_einsum("BTD,NDH->BTNH", x)
                kv_einsum = lora.Einsum(
                    shape=(2, config.num_kv_heads, config.width, config.head_dim),
                    name=_name("kv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                k, v = kv_einsum("BSD,2KDH->2BSKH", x)
                qkvs.append((q, k, v))

        # Concatenate all experts' q/k/v along the sequence axis for joint attention.
        q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))

        q = _apply_rope(q, positions=positions)
        q *= self.configs[0].head_dim ** -0.5  # standard 1/sqrt(head_dim) scaling

        k = _apply_rope(k, positions=positions)

        # should still be half-precision here (if input was half-precision)
        assert q.dtype == k.dtype == v.dtype == dtype

        if kv_cache is not None:
            # Prepend cached keys/values so new queries attend to the full history.
            cache_k, cache_v = kv_cache
            k = jnp.concatenate([cache_k, k], axis=1)
            v = jnp.concatenate([cache_v, v], axis=1)

        # Group query heads per kv head (grouped-query attention layout).
        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        # big_neg = jnp.finfo(logits.dtype).min
        big_neg = -2.3819763e38  # See gemma/modules.py
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        # Softmax in float32, then cast back to the working precision.
        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")

        # Split the joint attention output back into per-expert slices and apply
        # each expert's own output projection.
        out = []
        start = 0
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                end = start + x.shape[1]
                out_einsum = lora.Einsum(
                    shape=(config.num_heads, config.head_dim, config.width),
                    name=_name("attn_vec_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
                    lora_config=config.lora_configs.get("attn"),
                )
                out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
                start = end
            else:
                out.append(None)

        return out, (k, v)
|
||||
|
||||
|
||||
@at.typecheck
class FeedForward(nn.Module):
    """Feed forward module.

    Gated-GELU MLP: ``gelu(x @ W_gate) * (x @ W_up) @ W_down``, with all matmuls
    run in the input dtype.
    """

    # Input/output width (d_model).
    features: int
    # Hidden (expansion) dimension.
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # original dtype, could be half-precision
        # Both gate and up projections are stored in one (2, D, H) parameter.
        w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        ).astype(dtype)
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        ).astype(dtype)
        outputs = jnp.dot(activations, w_linear)
        assert outputs.dtype == dtype
        return outputs
|
||||
|
||||
|
||||
@at.typecheck
class Block(nn.Module):
    """Transformer block.

    Pre-norm block over a list of per-expert streams: (optionally adaptive)
    RMSNorm -> shared Attention -> gated residual, then RMSNorm -> per-expert
    FeedForward -> gated residual. ``adarms_cond[i]`` conditions expert i's
    norms; the gates returned by the adaptive norms scale the corresponding
    residual branch. None entries in ``xs`` pass through as None.
    """

    # One config per expert.
    configs: tuple[Config, ...]

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()

    @nn.compact
    def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True):  # noqa: FBT002
        xs = sharding.activation_sharding_constraint(xs)
        # No-op passthrough when dropout is disabled.
        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x

        attn = Attention(configs=self.configs, name="attn")

        # Pre-attention norm per expert; keep the adaRMS gates for the residual.
        pre_attn = []
        gates = []
        for i, x in enumerate(xs):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i])  # noqa: PLW2901
            pre_attn.append(x)
            # `gate` is only defined when x is not None; short-circuit keeps this safe.
            gates.append(gate if x is not None else None)

        pre_attn = sharding.activation_sharding_constraint(pre_attn)
        post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
        post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
        post_attn = sharding.activation_sharding_constraint(post_attn)
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        # Pre-FFN norm + per-expert MLP, mirroring the attention half above.
        out = []
        gates = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i])  # noqa: PLW2901
                x = lora.FeedForward(  # noqa: PLW2901
                    features=config.width,
                    hidden_dim=config.mlp_dim,
                    name=_name("mlp", i),
                    lora_config=config.lora_configs.get("ffn"),
                )(x)
            out.append(x)
            gates.append(gate if x is not None else None)

        out = sharding.activation_sharding_constraint(out)
        out = jax.tree.map(lambda x: drop(x, deterministic), out)
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        return xs, kv_cache
|
||||
|
||||
|
||||
# KV cache for all layers: (keys, values), each stacked over layers on the leading axis.
KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
|
||||
|
||||
|
||||
@at.typecheck
class Module(nn.Module):
    """Transformer model, supporting a mixture of different weights for different tokens."""

    configs: Sequence[Config]  # list of configs, one for each expert
    # dtype used for all activations (e.g. "bfloat16").
    embed_dtype: str

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.
    # NOTE(review): this flag is not read anywhere in this class; adaRMS is
    # driven by the `adarms_cond` argument / `init(use_adarms=...)` instead.
    adarms: bool = False

    def setup(self):
        # all experts must have the same depth
        assert all(config.depth == self.configs[0].depth for config in self.configs)

        self.embedder = Embedder(
            vocab_size=PALIGEMMA_VOCAB_SIZE,
            embed_dim=self.configs[0].width,  # embedder for first expert only
            name="embedder",
        )
        block_cls = nn.remat(
            Block,
            prevent_cse=False,
            # Counting `self` as argument 0, `deterministic` is argument 6 of
            # Block.__call__ (`adarms_cond` is argument 5 and must stay traced
            # -- it holds arrays). This matches gemma_fast.py's counting.
            static_argnums=(6,),  # 0=self, 6=deterministic
            policy=jax.checkpoint_policies.nothing_saveable,
        )
        # Scan the block over the depth axis; the per-layer KV cache is scanned
        # (axis 0) while positions/mask/condition/determinism are broadcast.
        self.layers = nn.scan(
            block_cls,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=(
                0,
                nn.broadcast,
                nn.broadcast,
                nn.broadcast,
                nn.broadcast,
            ),  # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
            length=self.configs[0].depth,
        )(
            configs=self.configs,
            dropout=self.dropout,
            dropout_bdims=self.dropout_bdims,
        )
        # One final norm per expert.
        self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]

    @at.typecheck
    def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
        """Embed tokens with the first expert's embedder, cast to the working dtype."""
        return self.embedder.encode(tokens).astype(self.embed_dtype)

    @at.typecheck
    def __call__(
        self,
        # list of token arrays, one for each expert, or None if that expert should not be run
        embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
        positions: at.Int[at.Array, "b t"],
        mask: at.Bool[at.Array, "b t s"],
        adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
        *,
        kv_cache: KVCache | None = None,
        deterministic: bool = True,
    ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
        """Run all layers over the per-expert streams; returns final-normed outputs and the KV cache."""
        embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
        mask = jnp.asarray(mask)[:, None, :, :]  # add a broadcast head axis: [B, 1, T, S]
        if adarms_cond is None:
            adarms_cond = [None] * len(self.configs)

        embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)

        assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)

        # Final per-expert RMSNorm; only the normed value is kept (gate discarded).
        return [
            f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
        ], kv_cache

    def init(self, use_adarms: Sequence[bool]):
        """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
        self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
        self(
            [jnp.zeros((1, 1, c.width)) for c in self.configs],
            jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
            jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
            adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
        )
|
||||
|
||||
|
||||
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
||||
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
||||
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None] / timescale[None, None, :]
|
||||
radians = radians[..., None, :]
|
||||
assert radians.dtype == jnp.float32
|
||||
# radians.shape = [...,L,1,d=D/2]
|
||||
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
||||
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
||||
assert res.dtype == jnp.float32
|
||||
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
|
||||
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
|
||||
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
|
||||
# here.
|
||||
return res.astype(x.dtype)
|
||||
|
||||
|
||||
def _name(name, i):
|
||||
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
|
||||
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
|
||||
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
|
||||
# and the action expert.
|
||||
if i == 0:
|
||||
return name
|
||||
return f"{name}_{i}"
|
||||
|
||||
|
||||
def _gated_residual(x, y, gate):
|
||||
assert (x is None) == (y is None)
|
||||
if x is None:
|
||||
return None
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
437
policy/openpi-InternData-A1/src/openpi/models/gemma_fast.py
Normal file
437
policy/openpi-InternData-A1/src/openpi/models/gemma_fast.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
|
||||
Used for FAST autoregressive policies.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import ml_collections
|
||||
|
||||
import openpi.models.lora as lora
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
# Names of the Gemma variants accepted by `get_config` in this module.
Variant = Literal["gemma_2b", "gemma_2b_lora"]
|
||||
|
||||
|
||||
def get_config(variant):
    """Returns config for specified gemma variant."""
    if variant not in ("gemma_2b", "gemma_2b_lora"):
        raise ValueError(f"Unknown variant: {variant}")
    # Both variants share the Gemma-2B architecture; the LoRA variant only adds
    # adapter settings on top.
    config = {
        "variant": variant,
        "width": 2048,
        "depth": 18,
        "mlp_dim": 16_384,
        "num_heads": 8,
        "num_kv_heads": 1,
        "head_dim": 256,
        "norm_eps": 1e-6,
        "vocab_size": 257_152,
        "scan": True,
        "remat_policy": "nothing_saveable",
    }
    if variant == "gemma_2b_lora":
        config["lora_configs"] = {
            "attn": lora.LoRAConfig(rank=16, alpha=16.0),
            "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
        }
    return ml_collections.ConfigDict(config)
|
||||
|
||||
|
||||
@at.typecheck
class Einsum(nn.Module):
    """A single einsum against a learned weight tensor of the given shape."""

    # Shape of the weight parameter `w`.
    shape: tuple[int, ...]

    @nn.compact
    def __call__(self, eqn, x):
        dtype = x.dtype  # original dtype, could be half-precision
        # Cast the weight to the input dtype so the einsum runs in the input precision.
        w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
        return jnp.einsum(eqn, x, w)
|
||||
|
||||
|
||||
@at.typecheck
class RMSNorm(nn.Module):
    """RMSNorm with a learned (1 + scale) multiplier, computed in float32."""

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # original dtype, could be half-precision
        # NOTE(review): `(x.shape[-1])` is an int, not a 1-tuple; the zeros
        # initializer accepts an int shape, so the parameter is still 1-D of size d.
        scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32
        normed_inputs = normed_inputs * (
            1 + scale
        )  # scale by learned parameter in float32 (matches Flax implementation)
        return normed_inputs.astype(dtype)  # return in original dtype
|
||||
|
||||
|
||||
@at.typecheck
class Embedder(nn.Module):
    """Embedder module.

    A single embedding table used both for input token lookup (`encode`) and,
    transposed, for projecting hidden states back to vocabulary logits (`decode`).
    """

    # Number of rows in the embedding table (token vocabulary size).
    vocab_size: int
    # Embedding dimension.
    embed_dim: int

    def setup(self):
        # Zero-initialized here (unlike gemma.py's normal init): presumably the
        # table is always loaded from a checkpoint -- TODO confirm.
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.zeros_init(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        """Look up token embeddings and apply the sqrt(d) input scaling."""
        # Tuple indexing `table[(x,)]` is equivalent to `table[x]`: gather rows by token id.
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        """Project hidden states to vocabulary logits using the tied embedding table."""
        return jnp.dot(x, self.input_embedding_table.T)
|
||||
|
||||
|
||||
@at.typecheck
class Attention(nn.Module):
    """Attention module.

    Grouped-query attention with an internal fixed-size KV cache for
    autoregressive decoding: the first call (``kv_cache is None``) prefWills the
    cache padded out to the mask width; subsequent calls append exactly one
    token at the current write index.
    """

    # Number of query heads.
    num_heads: int
    # Number of key/value heads; queries are grouped onto these.
    num_kv_heads: int
    # Model width (d_model).
    features: int
    # Dimension of each attention head.
    head_dim: int

    # Optional dtype for cached k/v (defaults to the k/v dtype).
    cache_dtype: str | None = None

    # Optional LoRA settings applied to all attention projections.
    lora_config: lora.LoRAConfig | None = None

    def setup(self):
        if self.num_kv_heads == self.num_heads:
            # Full multi-head attention: one fused qkv projection.
            self.qkv_einsum = lora.Einsum(
                shape=(3, self.num_heads, self.features, self.head_dim),
                name="qkv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        else:
            # Fewer kv heads than query heads: separate q and fused kv projections.
            self.q_einsum = lora.Einsum(
                shape=(self.num_heads, self.features, self.head_dim),
                name="q_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                lora_config=self.lora_config,
            )
            self.kv_einsum = lora.Einsum(
                shape=(2, self.num_kv_heads, self.features, self.head_dim),
                name="kv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        self.attn_vec_einsum = lora.Einsum(
            shape=(self.num_heads, self.head_dim, self.features),
            name="attn_vec_einsum",
            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            lora_config=self.lora_config,
        )

    def _init_cache(self, k, v, cache_size):
        """Initialize KV cache: right-pad the prefill k/v to `cache_size` and set the write index."""
        prefill_len = k.shape[1]
        pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
        cache_dtype = self.cache_dtype or k.dtype
        k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
        v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
        # Per-batch write index; all entries start at the prefill length.
        idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
        return idx, k_cache, v_cache

    def _update_cache(self, k, v, idx, k_cache, v_cache):
        """Update KV cache with new values (single-token decode step)."""
        assert k.shape[1] == 1, "Only support kv-cache updates of length 1"
        # Uses idx[0] for every batch row -- assumes all rows share the same
        # decode position.
        indices = (0, idx[0], 0, 0)
        cache_dtype = self.cache_dtype or k.dtype
        k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
        v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
        idx_new = idx + 1
        return idx_new, k_new, v_new

    @nn.compact
    def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True):  # noqa: FBT002
        dtype = x.dtype  # original dtype, could be half-precision
        if self.num_kv_heads == self.num_heads:
            q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
        else:
            q = self.q_einsum("BTD,NDH->BTNH", x)
            k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)

        q = _apply_rope(q, positions=positions)  # promotes to float32
        q *= self.head_dim**-0.5  # standard 1/sqrt(head_dim) scaling

        k = _apply_rope(k, positions=positions)  # promotes to float32

        if kv_cache is None:
            # Prefill: cache is sized by the mask's source dimension.
            idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
        else:
            idx, k_cache, v_cache = kv_cache
            idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)

        # Attend over the full (padded) cache; the mask hides unwritten slots.
        k, v = k_cache, v_cache
        kv_cache = (idx, k_cache, v_cache)

        # Group query heads per kv head (grouped-query attention layout).
        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        # big_neg = jnp.finfo(logits.dtype).min
        big_neg = -2.3819763e38  # See gemma/modules.py
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        # Softmax in float32, then cast back to the working precision.
        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
        return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
|
||||
|
||||
|
||||
@at.typecheck
class Block(nn.Module):
    """Transformer block.

    Standard pre-norm block: RMSNorm -> Attention -> dropout -> residual,
    then RMSNorm -> gated MLP -> dropout -> residual.
    """

    # Number of query heads.
    num_heads: int
    # Number of key/value heads.
    num_kv_heads: int
    # Model width (d_model).
    embed_dim: int
    # Dimension of each attention head.
    head_dim: int
    # Hidden dimension of the MLP.
    hidden_dim: int

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()
    # Optional dtype for the attention KV cache.
    cache_dtype: str | None = None
    # LoRA settings keyed by weight group ("attn", "ffn").
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    def setup(self):
        self.pre_attention_norm = RMSNorm()
        self.attn = Attention(
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            features=self.embed_dim,
            head_dim=self.head_dim,
            cache_dtype=self.cache_dtype,
            lora_config=self.lora_configs.get("attn"),
        )
        self.pre_ffw_norm = RMSNorm()
        self.mlp = lora.FeedForward(
            features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
        )
        if self.dropout:
            self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
        else:
            # No-op passthrough with the same (x, deterministic) call shape.
            self.drop = lambda x, _: x

    def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True):  # noqa: FBT002
        x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
        inputs_normalized = self.pre_attention_norm(x)
        attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
        attn_output = self.drop(attn_output, deterministic)
        attn_output += x  # attention residual
        residual = attn_output
        attn_output = self.pre_ffw_norm(attn_output)
        outputs = self.mlp(attn_output)
        outputs = self.drop(outputs, deterministic)
        outputs = residual + outputs  # MLP residual
        return outputs, kv_cache
|
||||
|
||||
|
||||
# Decode-time KV cache: (write index per batch element, key cache, value cache).
KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Module(nn.Module):
|
||||
"""gemma model."""
|
||||
|
||||
variant: str
|
||||
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
embed_dtype: str
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
|
||||
cache_dtype: str | None = None
|
||||
|
||||
scan: bool = False
|
||||
remat_policy: str = "none"
|
||||
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
|
||||
|
||||
@nn.compact
|
||||
def __call__(
|
||||
self,
|
||||
tokens=None,
|
||||
embedded_prefix=None,
|
||||
embed_only=False, # noqa: FBT002
|
||||
pre_logits=None,
|
||||
positions=None,
|
||||
mask=None,
|
||||
decode=False, # noqa: FBT002
|
||||
kv_cache=None,
|
||||
deterministic=True, # noqa: FBT002
|
||||
return_prelogits=False, # noqa: FBT002
|
||||
):
|
||||
"""Embed only, or complete forward pass.
|
||||
|
||||
Args:
|
||||
tokens: Embedded, then and appended to `embedded_prefix`. Can be None.
|
||||
embedded_prefix: Optional prefix that is already embedded.
|
||||
embed_only: Whether to compute embeddings only.
|
||||
pre_logits: If present computes logits from pre_logits and returns.
|
||||
positions: Optional `[B, T]` allows to specify the absolute position of
|
||||
the tokens.
|
||||
mask: Optional attention mask `[B, T, S]`.
|
||||
decode: Whether to use kv-cache. Caller must pass masks and positions.
|
||||
deterministic: Forwarded to all dropout layers.
|
||||
return_prelogits: Whether to return the pre-logits.
|
||||
|
||||
Returns:
|
||||
If `embed_only=False`, then `(logits, out)` will be returned.
|
||||
If `embed_only=True`, then the embeddings will be returned.
|
||||
If `return_prelogits=True`, then the pre-logits will be returned.
|
||||
"""
|
||||
out = {}
|
||||
|
||||
embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")
|
||||
|
||||
if pre_logits is not None:
|
||||
x = out["pre_logits"] = pre_logits
|
||||
logits = out["logits"] = embedder.decode(x)
|
||||
return logits, out
|
||||
|
||||
x = []
|
||||
if embedded_prefix is not None:
|
||||
x.append(embedded_prefix)
|
||||
if tokens is not None:
|
||||
x.append(embedder.encode(tokens))
|
||||
|
||||
x = jnp.concatenate(x, axis=-2)
|
||||
x = x.astype(self.embed_dtype)
|
||||
batch_size, seq_len, width = x.shape
|
||||
|
||||
if embed_only:
|
||||
return x
|
||||
|
||||
if decode:
|
||||
assert positions is not None and mask is not None, ( # noqa: PT018
|
||||
"Must explicitly pass positions and mask for decoding."
|
||||
)
|
||||
|
||||
if positions is None:
|
||||
positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
|
||||
assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
|
||||
|
||||
if mask is None:
|
||||
mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
|
||||
if mask.ndim == 3:
|
||||
mask = mask[:, None, :, :]
|
||||
cache_size = max(seq_len, mask.shape[-1])
|
||||
assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
|
||||
|
||||
if self.remat_policy == "none":
|
||||
block_cls = Block
|
||||
else:
|
||||
block_cls = nn.remat(
|
||||
Block,
|
||||
prevent_cse=not self.scan,
|
||||
static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
|
||||
policy=getattr(jax.checkpoint_policies, self.remat_policy),
|
||||
)
|
||||
|
||||
block_kw = {
|
||||
"num_heads": self.num_heads,
|
||||
"head_dim": self.head_dim,
|
||||
"num_kv_heads": self.num_kv_heads,
|
||||
"embed_dim": width,
|
||||
"hidden_dim": self.mlp_dim,
|
||||
"dropout": self.dropout,
|
||||
"dropout_bdims": self.dropout_bdims,
|
||||
"cache_dtype": self.cache_dtype,
|
||||
"lora_configs": self.lora_configs,
|
||||
}
|
||||
layers = self.scope.push("layers")
|
||||
blocks = [
|
||||
nn.scan(
|
||||
block_cls,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask
|
||||
length=self.depth,
|
||||
)(parent=layers, **block_kw)
|
||||
]
|
||||
for block in blocks:
|
||||
x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)
|
||||
|
||||
assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check.
|
||||
out["encoded"] = x
|
||||
|
||||
x = RMSNorm(name="final_norm")(x)
|
||||
out["pre_logits"] = x
|
||||
if return_prelogits:
|
||||
return x, kv_cache, out
|
||||
|
||||
x = embedder.decode(x)
|
||||
out["logits"] = x
|
||||
|
||||
return x, kv_cache, out
|
||||
|
||||
def init(self):
    """Initializes all module parameters via a dummy forward pass (required by linen's lazy init)."""
    # Tracing a single int32 token is enough to create every parameter.
    dummy_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
    self(dummy_tokens)
|
||||
|
||||
|
||||
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
||||
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
||||
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None] / timescale[None, None, :]
|
||||
radians = radians[..., None, :]
|
||||
assert radians.dtype == jnp.float32
|
||||
# radians.shape = [...,L,1,d=D/2]
|
||||
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
||||
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
||||
assert res.dtype == jnp.float32
|
||||
return res
|
||||
148
policy/openpi-InternData-A1/src/openpi/models/lora.py
Normal file
148
policy/openpi-InternData-A1/src/openpi/models/lora.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import math
|
||||
import re
|
||||
|
||||
import flax.linen as nn
|
||||
import flax.struct as struct
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
|
||||
@struct.dataclass
class LoRAConfig:
    """Configuration for LoRA (Low-Rank Adaptation) of a weight matrix."""

    # LoRA rank (the inner dimension of the low-rank factors).
    rank: int
    # LoRA scaling factor.
    alpha: float = 1.0
    # Initialization function for LoRA parameters.
    init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
    # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
    rslora: bool = False
    # Axes in the weight to apply LoRA to. Should typically be the last two axes.
    axes: tuple[int, int] = (-2, -1)
    # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
    label: str = "L"

    @property
    def scaling_value(self) -> float:
        """Multiplier applied to the low-rank correction: alpha/sqrt(rank) for rsLoRA, alpha/rank otherwise."""
        return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
|
||||
|
||||
|
||||
class Einsum(nn.Module):
    """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum.

    Computes `einsum(eqn, x, w)` and, when `lora_config` is set, adds the low-rank
    correction `einsum(eqn_b, einsum(eqn_a, x, w_a), w_b) * scaling_value`.
    """

    # Shape of the weight.
    shape: tuple[int, ...]
    # Initialization function for the weight.
    init_fn: nn.initializers.Initializer = nn.initializers.zeros
    # If not None, apply LoRA to the weight.
    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w = self.param("w", self.init_fn, self.shape)

        if config := self.lora_config:
            # Setup LoRA parameters: `lora_a` replaces the output axis with `rank`,
            # `lora_b` replaces the contracted axis with `rank`.
            shape_a, shape_b = list(self.shape), list(self.shape)
            shape_a[config.axes[1]] = config.rank
            shape_b[config.axes[0]] = config.rank
            self.w_a = self.param("lora_a", config.init_fn, shape_a)
            self.w_b = self.param("lora_b", config.init_fn, shape_b)

    @nn.compact
    def __call__(self, eqn: str, x):
        dtype = x.dtype  # original dtype, could be half-precision
        result = jnp.einsum(eqn, x, self.w.astype(dtype))

        if config := self.lora_config:
            eqn_a, eqn_b = self._make_lora_eqns(eqn)
            lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
            lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
            result = result + lora * config.scaling_value

        return result

    def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
        """Derives the two einsum equations for the low-rank factors from `eqn`.

        Raises:
            ValueError: if the configured LoRA label already appears in `eqn`, or
                if `eqn` is not of the form "lhs,rhs->out".
        """
        assert self.lora_config is not None
        label = self.lora_config.label
        # Fix: validate against the configured label rather than a hard-coded "L",
        # so custom labels are checked correctly (identical behavior for the default).
        if label in eqn:
            raise ValueError(f"{label} already in eqn: {eqn}")
        if not (m := re.match("(.*),(.*)->(.*)", eqn)):
            raise ValueError(f"Unsupported einsum eqn: {eqn}")
        lhs, rhs, out = m.groups()

        a_label, b_label = (rhs[x] for x in self.lora_config.axes)

        # First factor: contract into the rank axis (label replaces the output axis).
        a_rhs = rhs.replace(b_label, label)
        a_out = out.replace(b_label, label)
        eqn_a = f"{lhs},{a_rhs}->{a_out}"

        # Second factor: expand the rank axis back to the output axis.
        b_rhs = rhs.replace(a_label, label)
        eqn_b = f"{a_out},{b_rhs}->{out}"

        return eqn_a, eqn_b
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Gated feed-forward module (gelu(x @ W_gate) * (x @ W_up)) @ W_down, with optional LoRA on every matmul."""

    # Input/output feature dimension.
    features: int
    # Hidden (expansion) dimension.
    hidden_dim: int
    # If not None, apply LoRA to the weight.
    lora_config: LoRAConfig | None = None

    def setup(self):
        # Gate and up projections are stacked along a leading axis of size 2.
        self.w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        )
        self.w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        )
        self.w_gating_lora = None
        self.w_linear_lora = None
        if self.lora_config:
            # Setup LoRA parameters.
            # TODO: follow up with a simplified init_fn api.
            # Each entry is an (A, B) pair; gating factors keep the stacked leading axis of 2.
            self.w_gating_lora = (
                self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
                self.param(
                    "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
                ),
            )
            self.w_linear_lora = (
                self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
                self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
            )

    @nn.compact
    def __call__(self, x):
        """Applies the gated MLP; the output keeps the input dtype."""
        dtype = x.dtype  # original dtype, could be half-precision
        # Gate branch: w_gating[0] (+ its LoRA pair, if any), then GELU.
        ff_gate = self._dot(
            x,
            self.w_gating[0],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
        )
        gate_value = nn.gelu(ff_gate)

        # Up branch: w_gating[1] (+ its LoRA pair, if any).
        ff1 = self._dot(
            x,
            self.w_gating[1],
            None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
        )
        activations = gate_value * ff1

        # Down projection.
        outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
        assert outputs.dtype == dtype
        return outputs

    def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
        """x @ w plus, if given, the low-rank correction x @ A @ B.

        NOTE(review): unlike `Einsum.__call__`, no `scaling_value` factor is applied here — confirm intended.
        """
        base = jnp.dot(x, w.astype(x.dtype))
        if lora_weights is None:
            return base
        return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
|
||||
94
policy/openpi-InternData-A1/src/openpi/models/lora_test.py
Normal file
94
policy/openpi-InternData-A1/src/openpi/models/lora_test.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.lora as lora
|
||||
|
||||
|
||||
def test_lora_einsum_params_shape():
    """LoRA params exist only when configured, with shapes following the chosen axes."""
    weight_shape = (3, 8, 32, 4)  # (3KDH)
    plain = lora.Einsum(weight_shape)
    default_axes = lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2))
    custom_axes = lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))

    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"

    # Ensure that lora parameters are not initialized when LoRA is not used.
    plain_params = plain.init(rng, eqn, inputs)["params"]
    assert "lora_a" not in plain_params
    assert "lora_b" not in plain_params

    # Check that default axes (the last two) work.
    default_params = default_axes.init(rng, eqn, inputs)["params"]
    assert default_params["lora_a"].shape == (3, 8, 32, 2)
    assert default_params["lora_b"].shape == (3, 8, 2, 4)

    # Check that user provided axes work.
    custom_params = custom_axes.init(rng, eqn, inputs)["params"]
    assert custom_params["lora_a"].shape == (3, 8, 2, 4)
    assert custom_params["lora_b"].shape == (3, 2, 32, 4)
|
||||
|
||||
|
||||
def test_lora_einsum_same_output():
    """Zero-initialized LoRA factors must leave the einsum output unchanged."""
    weight_shape = (3, 8, 32, 4)  # (3KDH)
    baseline = lora.Einsum(weight_shape)
    adapted = lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))

    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (8, 64, 32))  # (BSD)
    eqn = "BSD,3KDH->3BSKH"

    baseline_out = baseline.apply(baseline.init(rng, eqn, inputs), eqn, inputs)
    adapted_out = adapted.apply(adapted.init(rng, eqn, inputs), eqn, inputs)

    # The low-rank correction vanishes when both factors are zero.
    assert jnp.allclose(baseline_out, adapted_out)
|
||||
|
||||
|
||||
def test_lora_ffn_params_shape():
    """Base and LoRA parameter shapes of FeedForward."""
    base_ffn = lora.FeedForward(features=8, hidden_dim=32)
    lora_ffn = lora.FeedForward(
        features=8,
        hidden_dim=32,
        lora_config=lora.LoRAConfig(rank=2),
    )

    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (2, 8))

    base_params = base_ffn.init(rng, inputs)["params"]
    assert base_params["gating_einsum"].shape == (2, 8, 32)
    assert base_params["linear"].shape == (32, 8)

    lora_params = lora_ffn.init(rng, inputs)["params"]
    # Base weights are unchanged by the LoRA config...
    assert lora_params["gating_einsum"].shape == (2, 8, 32)
    assert lora_params["linear"].shape == (32, 8)
    # ...and the low-rank factors carry the rank-2 inner dimension.
    assert lora_params["gating_einsum_lora_a"].shape == (2, 8, 2)
    assert lora_params["gating_einsum_lora_b"].shape == (2, 2, 32)
    assert lora_params["linear_lora_a"].shape == (32, 2)
    assert lora_params["linear_lora_b"].shape == (2, 8)
|
||||
|
||||
|
||||
def test_lora_ffn_same_output():
    """Zero-initialized LoRA must not change the FeedForward output."""
    base_ffn = lora.FeedForward(features=8, hidden_dim=32)
    lora_ffn = lora.FeedForward(
        features=8,
        hidden_dim=32,
        lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
    )

    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (2, 8))

    base_out = base_ffn.apply(base_ffn.init(rng, inputs), inputs)
    lora_out = lora_ffn.apply(lora_ffn.init(rng, inputs), inputs)

    assert jnp.allclose(base_out, lora_out)
|
||||
332
policy/openpi-InternData-A1/src/openpi/models/model.py
Normal file
332
policy/openpi-InternData-A1/src/openpi/models/model.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import abc
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import augmax
|
||||
from flax import nnx
|
||||
from flax import struct
|
||||
from flax import traverse_util
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from openpi.models_pytorch import pi0_pytorch
|
||||
from openpi.shared import image_tools
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
|
||||
ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
    """Supported model types.

    The enum value is the canonical string name for the model family.
    """

    PI0 = "pi0"
    PI0_FAST = "pi0_fast"
    PI05 = "pi05"
|
||||
|
||||
|
||||
# The model always expects these images; `preprocess_observation` raises if any are missing.
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)


# Default input image resolution. This may need change if we release a small model.
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
# Data format
|
||||
#
|
||||
# Data transforms produce the model input as a nested dictionary which is later converted
|
||||
# into `Observation` and `Actions` objects. See below.
|
||||
#
|
||||
# In the dictionary form, this data should look like:
|
||||
# {
|
||||
# # Observation data.
|
||||
# "image": {
|
||||
# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
|
||||
# ... # Additional camera views
|
||||
# },
|
||||
# "image_mask": {
|
||||
# "base_0_rgb": bool[*b], # True if image is valid
|
||||
# ... # Masks for additional views
|
||||
# },
|
||||
# "state": float32[*b, s], # Low-dimensional robot state
|
||||
# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
|
||||
# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
|
||||
# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
|
||||
# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
|
||||
#
|
||||
# # Actions data.
|
||||
# "actions": float32[*b ah ad]
|
||||
# }
|
||||
# where:
|
||||
# *b = batch dimensions
|
||||
# h,w = image height/width
|
||||
# s = state dimension
|
||||
# l = sequence length
|
||||
#
|
||||
@at.typecheck
@struct.dataclass
class Observation(Generic[ArrayT]):
    """Holds observations, i.e., inputs to the model.

    See `Observation.from_dict` to see the expected dictionary form. This is the format
    that should be produced by the data transforms.
    """

    # Images, in [-1, 1] float32.
    images: dict[str, at.Float[ArrayT, "*b h w c"]]
    # Image masks, with same keys as images. True where the image is valid.
    image_masks: dict[str, at.Bool[ArrayT, "*b"]]
    # Low-dimensional robot state.
    state: at.Float[ArrayT, "*b s"]

    # Tokenized prompt. Must be provided together with `tokenized_prompt_mask` (see `from_dict`).
    tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
    # Tokenized prompt mask.
    tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None

    # pi0-fast model specific fields.

    # Token auto-regressive mask (for FAST autoregressive model).
    token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
    # Token loss mask (for FAST autoregressive model).
    token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None

    @classmethod
    def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
        """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format.

        NOTE: mutates `data["image"]` in place when converting uint8 images.
        """
        # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
        if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
            raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
        # If images are uint8, convert them to [-1, 1] float32.
        for key in data["image"]:
            if data["image"][key].dtype == np.uint8:
                data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
            elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
                # Torch images are additionally permuted to channel-first; assumes a 4-D NHWC batch — TODO confirm.
                data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
        return cls(
            images=data["image"],
            image_masks=data["image_mask"],
            state=data["state"],
            tokenized_prompt=data.get("tokenized_prompt"),
            tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
            token_ar_mask=data.get("token_ar_mask"),
            token_loss_mask=data.get("token_loss_mask"),
        )

    def to_dict(self) -> at.PyTree[ArrayT]:
        """Convert the Observation to a nested dict."""
        result = dataclasses.asdict(self)
        # Restore the external key names expected by the data transforms.
        result["image"] = result.pop("images")
        result["image_mask"] = result.pop("image_masks")
        return result
|
||||
|
||||
|
||||
# Defines the format of the actions. This field is included as "actions" inside the dictionary
# produced by the data transforms.
# Shape: *b = batch dims, ah = action horizon, ad = action dimension.
Actions = at.Float[ArrayT, "*b ah ad"]
|
||||
|
||||
|
||||
def preprocess_observation(
    rng: at.KeyArrayLike | None,
    observation: Observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> Observation:
    """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
    filling in a default image mask (if necessary).

    Args:
        rng: PRNG key for augmentations; only used — and required — when train=True.
        observation: the observation to preprocess.
        train: enables random crop/resize/rotate/color-jitter augmentations.
        image_keys: images that must be present in `observation.images`.
        image_resolution: target resolution, compared against `image.shape[1:3]`.

    Returns:
        A new Observation with processed images and masks; all other fields pass through unchanged.

    Raises:
        ValueError: if any of `image_keys` is missing from `observation.images`.
    """

    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")

    batch_shape = observation.state.shape[:-1]

    out_images = {}
    for key in image_keys:
        image = observation.images[key]
        if image.shape[1:3] != image_resolution:
            logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
            image = image_tools.resize_with_pad(image, *image_resolution)

        if train:
            # Convert from [-1, 1] to [0, 1] for augmax.
            image = image / 2.0 + 0.5

            transforms = []
            # Wrist cameras skip the geometric augmentations (crop/resize/rotate).
            if "wrist" not in key:
                height, width = image.shape[1:3]
                transforms += [
                    augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
                    augmax.Resize(width, height),
                    augmax.Rotate((-5, 5)),
                ]
            transforms += [
                augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
            ]
            # One augmentation rng per batch element; `rng` must not be None here.
            sub_rngs = jax.random.split(rng, image.shape[0])
            image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)

            # Back to [-1, 1].
            image = image * 2.0 - 1.0

        out_images[key] = image

    # obtain mask
    out_masks = {}
    for key in out_images:
        if key not in observation.image_masks:
            # do not mask by default
            out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
        else:
            out_masks[key] = jnp.asarray(observation.image_masks[key])

    return Observation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class BaseModelConfig(abc.ABC):
    """Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
    method to create the corresponding model.
    """

    # Action space dimension.
    action_dim: int
    # Action sequence length.
    action_horizon: int
    # Tokenized prompt maximum length.
    max_token_len: int

    @property
    @abc.abstractmethod
    def model_type(self) -> ModelType:
        """The model type."""

    @abc.abstractmethod
    def create(self, rng: at.KeyArrayLike) -> "BaseModel":
        """Create a new model, initializing parameters."""

    def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
        """Create a model with the given parameters.

        Args:
            params: pure-dict parameter pytree to load into the model.
            remove_extra_params: if True, silently drop entries of `params` that the model does not define.
        """
        # eval_shape traces `create` without materializing real weights.
        model = nnx.eval_shape(self.create, jax.random.key(0))
        graphdef, state = nnx.split(model)
        if remove_extra_params:
            params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
        # Fail early if the provided params do not match the model's structure/shapes.
        at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
        state.replace_by_pure_dict(params)
        return nnx.merge(graphdef, state)

    def load_pytorch(self, train_config, weight_path: str):
        """Builds a PI0Pytorch model from `train_config.model` and loads safetensors weights from `weight_path`."""
        logger.info(f"train_config: {train_config}")
        model = pi0_pytorch.PI0Pytorch(config=train_config.model)
        safetensors.torch.load_model(model, weight_path)
        return model

    @abc.abstractmethod
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
        """Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""

    def fake_obs(self, batch_size: int = 1) -> Observation:
        """Returns an all-ones Observation matching `inputs_spec` (useful for init and tests)."""
        observation_spec, _ = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)

    def fake_act(self, batch_size: int = 1) -> Actions:
        """Returns an all-ones Actions array matching `inputs_spec` (useful for init and tests)."""
        _, action_spec = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BaseModel(nnx.Module, abc.ABC):
    """Base class for all model implementations. Specific models should inherit from this class. They should call
    super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
    """

    # Action space dimension (ad).
    action_dim: int
    # Action sequence length (ah).
    action_horizon: int
    # Tokenized prompt maximum length.
    max_token_len: int

    # Returns a per-(batch, action-step) loss of shape [*b, ah].
    @abc.abstractmethod
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: Observation,
        actions: Actions,
        *,
        train: bool = False,
    ) -> at.Float[at.Array, "*b ah"]: ...

    # Samples an action chunk for the given observation; extra kwargs are model-specific.
    @abc.abstractmethod
    def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ...
|
||||
|
||||
|
||||
def restore_params(
    params_path: pathlib.Path | str,
    *,
    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
    dtype: jnp.dtype | None = None,
    sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
    """Restores unstructured params PyTree from a checkpoint.

    This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
    well as pre-trained checkpoints released for openpi.

    Args:
        params_path: The local path to the checkpoint directory.
        restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
        dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
        sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.

    Returns:
        The restored params.
    """
    # Keep GCS URIs untouched; resolve local paths to absolute.
    params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path

    if restore_type is jax.Array and sharding is None:
        # Default: fully replicated across all devices (empty PartitionSpec).
        mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    with ocp.PyTreeCheckpointer() as ckptr:
        # Restore only the "params" subtree of the checkpoint.
        metadata = ckptr.metadata(params_path)
        item = {"params": metadata["params"]}

        params = ckptr.restore(
            params_path,
            ocp.args.PyTreeRestore(
                item=item,
                restore_args=jax.tree.map(
                    lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
                ),
            ),
        )["params"]

    # If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
    # added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
    flat_params = traverse_util.flatten_dict(params)
    if all(kp[-1] == "value" for kp in flat_params):
        flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
    return traverse_util.unflatten_dict(flat_params)
|
||||
94
policy/openpi-InternData-A1/src/openpi/models/model_test.py
Normal file
94
policy/openpi-InternData-A1/src/openpi/models/model_test.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from flax import nnx
|
||||
import jax
|
||||
import pytest
|
||||
|
||||
from openpi.models import model as _model
|
||||
from openpi.models import pi0_config
|
||||
from openpi.models import pi0_fast
|
||||
from openpi.shared import download
|
||||
from openpi.shared import nnx_utils
|
||||
|
||||
|
||||
def test_pi0_model():
    """Smoke-tests loss and sampling shapes for the base Pi0 model."""
    rng = jax.random.key(0)
    config = pi0_config.Pi0Config()
    model = config.create(rng)

    batch_size = 2
    obs = config.fake_obs(batch_size)
    act = config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (batch_size, config.action_horizon)

    sampled = nnx_utils.module_jit(model.sample_actions)(rng, obs, num_steps=10)
    assert sampled.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
|
||||
|
||||
def test_pi0_lora_model():
    """Smoke-tests loss and sampling shapes for the LoRA Pi0 variant."""
    rng = jax.random.key(0)
    config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
    model = config.create(rng)

    batch_size = 2
    obs = config.fake_obs(batch_size)
    act = config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (batch_size, config.action_horizon)

    sampled = nnx_utils.module_jit(model.sample_actions)(rng, obs, num_steps=10)
    assert sampled.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
|
||||
|
||||
def test_pi0_fast_model():
    """Smoke-tests loss and sampling shapes for the Pi0-FAST model."""
    rng = jax.random.key(0)
    config = pi0_fast.Pi0FASTConfig()
    model = config.create(rng)

    batch_size = 2
    obs = config.fake_obs(batch_size)
    act = config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (batch_size,)

    sampled = nnx_utils.module_jit(model.sample_actions)(rng, obs)
    assert sampled.shape == (batch_size, 256)
|
||||
|
||||
|
||||
def test_pi0_fast_lora_model():
    """Smoke-tests the LoRA Pi0-FAST variant and checks that LoRA params are registered."""
    rng = jax.random.key(0)
    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
    model = config.create(rng)

    batch_size = 2
    obs = config.fake_obs(batch_size)
    act = config.fake_act(batch_size)

    loss = nnx_utils.module_jit(model.compute_loss)(rng, obs, act)
    assert loss.shape == (batch_size,)

    sampled = nnx_utils.module_jit(model.sample_actions)(rng, obs)
    assert sampled.shape == (batch_size, 256)

    # The LoRA variant must actually contain LoRA parameters.
    model_state = nnx.state(model)
    assert list(model_state.filter(nnx_utils.PathRegex(".*lora.*")))
|
||||
|
||||
|
||||
@pytest.mark.manual
def test_model_restore():
    """Loads the released pi0_base checkpoint and smoke-tests it (manual: downloads weights)."""
    rng = jax.random.key(0)
    config = pi0_config.Pi0Config()

    batch_size = 2
    obs = config.fake_obs(batch_size)
    act = config.fake_act(batch_size)

    ckpt_path = download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params")
    model = config.load(_model.restore_params(ckpt_path))

    loss = model.compute_loss(rng, obs, act)
    assert loss.shape == (batch_size, config.action_horizon)

    sampled = model.sample_actions(rng, obs, num_steps=10)
    assert sampled.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
279
policy/openpi-InternData-A1/src/openpi/models/pi0.py
Normal file
279
policy/openpi-InternData-A1/src/openpi/models/pi0.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import logging
|
||||
|
||||
import einops
|
||||
import flax.nnx as nnx
|
||||
import flax.nnx.bridge as nnx_bridge
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi.models import model as _model
|
||||
from openpi.models import pi0_config
|
||||
import openpi.models.gemma as _gemma
|
||||
import openpi.models.siglip as _siglip
|
||||
from openpi.shared import array_typing as at
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
|
||||
def make_attn_mask(input_mask, mask_ar):
    """Adapted from big_vision.

    Builds a [B, N, N] attention mask from a validity mask and a block-boundary
    mask: a query may attend a key iff the key's cumulative `mask_ar` count does
    not exceed the query's and both tokens are valid. This expresses several
    attention patterns:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention — the first 3 tokens attend among
          themselves, the last 3 are causal. The first entry could also be a 1
          without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks; tokens of a
          block attend all previous blocks and all tokens in the same block.

    Args:
        input_mask: bool[B, N] true if its part of the input, false if padding.
        mask_ar: bool[?B, N] true where previous tokens cannot depend on the
            token, false where the token shares the previous token's mask.
    """
    boundaries = jnp.broadcast_to(mask_ar, input_mask.shape)
    block_id = jnp.cumsum(boundaries, axis=1)
    # Query i can see key j iff j's block does not come after i's.
    visibility = block_id[:, None, :] <= block_id[:, :, None]
    both_valid = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(visibility, both_valid)
|
||||
|
||||
|
||||
@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Sine-cosine positional embeddings for scalar positions.

    Half of the channels hold sines and half cosines, with periods spaced
    geometrically between min_period and max_period.
    """
    if embedding_dim % 2 != 0:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    # Geometric interpolation of periods across the half-dimension.
    fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
    period = min_period * (max_period / min_period) ** fraction
    angular_freq = 1.0 / period * 2 * jnp.pi
    angles = jnp.einsum("i,j->ij", pos, angular_freq, precision=jax.lax.Precision.HIGHEST)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
|
||||
|
||||
|
||||
class Pi0(_model.BaseModel):
|
||||
def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
    """Builds the PaliGemma backbone (Gemma LLM + SigLIP image encoder) and the action-expert projections.

    Args:
        config: model hyperparameters; `config.pi05` selects the pi0.5 variant (adaRMS + time MLP).
        rngs: NNX random streams used to initialize all parameters.
    """
    super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
    self.pi05 = config.pi05
    paligemma_config = _gemma.get_config(config.paligemma_variant)
    action_expert_config = _gemma.get_config(config.action_expert_variant)
    # TODO: rewrite gemma in NNX. For now, use bridge.
    llm = nnx_bridge.ToNNX(
        _gemma.Module(
            configs=[paligemma_config, action_expert_config],
            embed_dtype=config.dtype,
            adarms=config.pi05,
        )
    )
    # adaRMS is enabled only for the action expert (second config) in the pi0.5 variant.
    llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False])
    img = nnx_bridge.ToNNX(
        _siglip.Module(
            num_classes=paligemma_config.width,
            variant="So400m/14",
            pool_type="none",
            scan=True,
            dtype_mm=config.dtype,
        )
    )
    # Initialize the image tower using a correctly-shaped fake image from the config spec.
    img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
    self.PaliGemma = nnx.Dict(llm=llm, img=img)
    self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
    if config.pi05:
        # pi0.5: time conditioning goes through a dedicated MLP (paired with adaRMS above).
        self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
        self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
    else:
        # pi0: a single projected state token plus an action+time MLP.
        # NOTE(review): state_proj input size is action_dim — presumably state is padded to action_dim; confirm.
        self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
        self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
        self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
    self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)

    # This attribute gets automatically set by model.train() and model.eval().
    self.deterministic = True
|
||||
|
||||
@at.typecheck
def embed_prefix(
    self, obs: _model.Observation
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
    """Embeds the prefix: all camera images plus the optional tokenized prompt.

    Returns:
        tokens: [b, s, emb] concatenated image and language embeddings.
        input_mask: [b, s] validity mask (image masks broadcast over their token span).
        ar_mask: [s] autoregressive boundary mask; all False, i.e. full bidirectional
            attention within the prefix (see `make_attn_mask`).
    """
    input_mask = []
    ar_mask = []
    tokens = []
    # embed images
    for name in obs.images:
        image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)

        tokens.append(image_tokens)
        # Broadcast the per-image validity bit across that image's token span.
        input_mask.append(
            einops.repeat(
                obs.image_masks[name],
                "b -> b s",
                s=image_tokens.shape[1],
            )
        )
        # image tokens attend to each other
        ar_mask += [False] * image_tokens.shape[1]

    # add language (aka tokenized inputs)
    if obs.tokenized_prompt is not None:
        tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
        tokens.append(tokenized_inputs)
        input_mask.append(obs.tokenized_prompt_mask)
        # full attention between image and language inputs
        ar_mask += [False] * tokenized_inputs.shape[1]
    tokens = jnp.concatenate(tokens, axis=1)
    input_mask = jnp.concatenate(input_mask, axis=1)
    ar_mask = jnp.array(ar_mask)
    return tokens, input_mask, ar_mask
|
||||
|
||||
@at.typecheck
|
||||
def embed_suffix(
|
||||
self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
|
||||
) -> tuple[
|
||||
at.Float[at.Array, "b s emb"],
|
||||
at.Bool[at.Array, "b s"],
|
||||
at.Bool[at.Array, " s"],
|
||||
at.Float[at.Array, "b emb"] | None,
|
||||
]:
|
||||
input_mask = []
|
||||
ar_mask = []
|
||||
tokens = []
|
||||
if not self.pi05:
|
||||
# add a single state token
|
||||
state_token = self.state_proj(obs.state)[:, None, :]
|
||||
tokens.append(state_token)
|
||||
input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
|
||||
# image/language inputs do not attend to state or actions
|
||||
ar_mask += [True]
|
||||
|
||||
action_tokens = self.action_in_proj(noisy_actions)
|
||||
# embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
|
||||
if self.pi05:
|
||||
# time MLP (for adaRMS)
|
||||
time_emb = self.time_mlp_in(time_emb)
|
||||
time_emb = nnx.swish(time_emb)
|
||||
time_emb = self.time_mlp_out(time_emb)
|
||||
time_emb = nnx.swish(time_emb)
|
||||
action_expert_tokens = action_tokens
|
||||
adarms_cond = time_emb
|
||||
else:
|
||||
# mix timestep + action information using an MLP (no adaRMS)
|
||||
time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
|
||||
action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
|
||||
action_time_tokens = self.action_time_mlp_in(action_time_tokens)
|
||||
action_time_tokens = nnx.swish(action_time_tokens)
|
||||
action_time_tokens = self.action_time_mlp_out(action_time_tokens)
|
||||
action_expert_tokens = action_time_tokens
|
||||
adarms_cond = None
|
||||
tokens.append(action_expert_tokens)
|
||||
input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))
|
||||
# image/language/state inputs do not attend to action tokens
|
||||
ar_mask += [True] + ([False] * (self.action_horizon - 1))
|
||||
tokens = jnp.concatenate(tokens, axis=1)
|
||||
input_mask = jnp.concatenate(input_mask, axis=1)
|
||||
ar_mask = jnp.array(ar_mask)
|
||||
return tokens, input_mask, ar_mask, adarms_cond
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
|
||||
) -> at.Float[at.Array, "*b ah"]:
|
||||
preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
|
||||
observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
|
||||
|
||||
batch_shape = actions.shape[:-2]
|
||||
noise = jax.random.normal(noise_rng, actions.shape)
|
||||
time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
|
||||
time_expanded = time[..., None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# one big forward pass of prefix + suffix at once
|
||||
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
|
||||
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
|
||||
input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
|
||||
ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
|
||||
attn_mask = make_attn_mask(input_mask, ar_mask)
|
||||
positions = jnp.cumsum(input_mask, axis=1) - 1
|
||||
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
|
||||
[prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
|
||||
)
|
||||
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
|
||||
|
||||
return jnp.mean(jnp.square(v_t - u_t), axis=-1)
|
||||
|
||||
@override
|
||||
def sample_actions(
|
||||
self,
|
||||
rng: at.KeyArrayLike,
|
||||
observation: _model.Observation,
|
||||
*,
|
||||
num_steps: int | at.Int[at.Array, ""] = 10,
|
||||
noise: at.Float[at.Array, "b ah ad"] | None = None,
|
||||
) -> _model.Actions:
|
||||
observation = _model.preprocess_observation(None, observation, train=False)
|
||||
# note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
|
||||
# distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
|
||||
dt = -1.0 / num_steps
|
||||
batch_size = observation.state.shape[0]
|
||||
if noise is None:
|
||||
noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
|
||||
|
||||
# first fill KV cache with a forward pass of the prefix
|
||||
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
|
||||
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
|
||||
positions = jnp.cumsum(prefix_mask, axis=1) - 1
|
||||
_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)
|
||||
|
||||
def step(carry):
|
||||
x_t, time = carry
|
||||
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
|
||||
observation, x_t, jnp.broadcast_to(time, batch_size)
|
||||
)
|
||||
# `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
|
||||
# other
|
||||
suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
|
||||
# `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
|
||||
# prefix tokens
|
||||
prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
|
||||
# `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
|
||||
# generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
|
||||
full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
|
||||
assert full_attn_mask.shape == (
|
||||
batch_size,
|
||||
suffix_tokens.shape[1],
|
||||
prefix_tokens.shape[1] + suffix_tokens.shape[1],
|
||||
)
|
||||
# `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
|
||||
positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
|
||||
|
||||
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
|
||||
[None, suffix_tokens],
|
||||
mask=full_attn_mask,
|
||||
positions=positions,
|
||||
kv_cache=kv_cache,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
assert prefix_out is None
|
||||
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
|
||||
|
||||
return x_t + dt * v_t, time + dt
|
||||
|
||||
def cond(carry):
|
||||
x_t, time = carry
|
||||
# robust to floating-point error
|
||||
return time >= -dt / 2
|
||||
|
||||
x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
|
||||
return x_0
|
||||
108
policy/openpi-InternData-A1/src/openpi/models/pi0_config.py
Normal file
108
policy/openpi-InternData-A1/src/openpi/models/pi0_config.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import flax.nnx as nnx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi.models import model as _model
|
||||
import openpi.models.gemma as _gemma
|
||||
from openpi.shared import array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openpi.models.pi0 import Pi0
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Pi0Config(_model.BaseModelConfig):
    """Configuration for the Pi0 / Pi05 flow-matching action model."""

    # Computation dtype for the transformer towers (e.g. "bfloat16", "float32").
    dtype: str = "bfloat16"
    # Variant of the PaliGemma backbone (prefix LLM).
    paligemma_variant: _gemma.Variant = "gemma_2b"
    # Variant of the smaller Gemma action expert (suffix).
    action_expert_variant: _gemma.Variant = "gemma_300m"

    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 50
    # None means "derive from pi05 in __post_init__" (200 for pi05, 48 for pi0).
    max_token_len: int = None  # type: ignore
    # Pi05 has two differences from Pi0:
    # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix
    # - the action expert uses adaRMSNorm to inject the flow matching timestep
    pi05: bool = False
    # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
    # None means "derive from pi05 in __post_init__".
    discrete_state_input: bool = None  # type: ignore

    def __post_init__(self):
        # The dataclass is frozen, so derived defaults must go through object.__setattr__.
        if self.max_token_len is None:
            object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
        if self.discrete_state_input is None:
            object.__setattr__(self, "discrete_state_input", self.pi05)

    @property
    @override
    def model_type(self) -> _model.ModelType:
        """Model type tag: PI05 when pi05 is enabled, else PI0."""
        if self.pi05:
            return _model.ModelType.PI05
        return _model.ModelType.PI0

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0":
        """Instantiates a Pi0 model from this config (local import avoids a cycle)."""
        from openpi.models.pi0 import Pi0

        return Pi0(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Returns shape/dtype specs for (observation, actions) expected by the model."""
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "left_wrist_0_rgb": image_spec,
                    "right_wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "left_wrist_0_rgb": image_mask_spec,
                    "right_wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)

        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config.

        When a "lora" variant is selected, the corresponding base (non-LoRA)
        params are frozen; LoRA params themselves are always left trainable.
        """
        filters = []
        has_lora = False
        # All LLM params (both PaliGemma and the action expert live under "llm").
        gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
        # Action-expert params are distinguished by "_1" in their path.
        action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
        if "lora" in self.paligemma_variant:
            filters.append(
                gemma_params_filter,
            )
            if "lora" not in self.action_expert_variant:
                # If only freeze gemma params, exclude action expert params.
                filters.append(
                    nnx.Not(action_expert_params_filter),
                )
            has_lora = True
        elif "lora" in self.action_expert_variant:
            filters.append(
                action_expert_params_filter,
            )
            has_lora = True

        if has_lora:
            # If any lora is used, exclude all lora params.
            filters.append(
                nnx.Not(nnx_utils.PathRegex(".*lora.*")),
            )
        if not filters:
            return nnx.Nothing
        return nnx.All(*filters)
|
||||
313
policy/openpi-InternData-A1/src/openpi/models/pi0_fast.py
Normal file
313
policy/openpi-InternData-A1/src/openpi/models/pi0_fast.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import flax.nnx as nnx
|
||||
import flax.nnx.bridge as nnx_bridge
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi.models import model as _model
|
||||
import openpi.models.gemma_fast as _gemma
|
||||
import openpi.models.siglip as _siglip
|
||||
from openpi.shared import array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
|
||||
logger = logging.getLogger("openpi")

# Token id used to detect end-of-sequence during FAST decoding
# (see Pi0FAST.sample_actions early-stopping check).
PALIGEMMA_EOS_TOKEN = 1
|
||||
|
||||
|
||||
def make_attn_mask(input_mask, mask_ar):
    """Adapted from big_vision.

    Builds a [B, N, N] boolean attention mask from a validity mask and a
    block-autoregressive marker. A token may attend to any valid token whose
    cumulative `mask_ar` count is less than or equal to its own, so
    `mask_ar` bool[?B, N] expresses several attention patterns:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      input_mask: bool[B, N] true if its part of the input, false if padding.
      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
        it and false where it shares the same attention mask as the previous token.
    """
    ar = jnp.broadcast_to(mask_ar, input_mask.shape)
    # Tokens sharing a cumulative-sum value form one mutually-attending block.
    block_ids = jnp.cumsum(ar, axis=1)
    may_attend = block_ids[:, None, :] <= block_ids[:, :, None]
    # Padding positions can neither attend nor be attended to.
    both_valid = input_mask[:, :, None] * input_mask[:, None, :]
    return jnp.logical_and(may_attend, both_valid)
|
||||
|
||||
|
||||
@jax.vmap
def left_to_right_align(x, input_mask, attn_mask):
    """Shifts a left-aligned sequence so its valid tokens end at the last slot.

    Vmapped over the batch axis, so the body sees one example at a time.
    Rolls `x`, `input_mask`, and `attn_mask` left by the valid length; the
    valid prefix wraps around to become a right-aligned suffix.
    """
    assert x.ndim == 2
    assert input_mask.ndim == 1
    assert attn_mask.ndim == 2
    assert x.shape[0] == input_mask.shape[0]
    assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
    # Valid length = (index of the last True in the mask) + 1.
    n_valid = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1
    shift = -n_valid
    return (
        jnp.roll(x, shift, axis=0),
        jnp.roll(input_mask, shift, axis=0),
        jnp.roll(attn_mask, shift, axis=(0, 1)),
    )
|
||||
|
||||
|
||||
def put_along_last_axis(arr, indices, values):
    """Like np.put_along_axis(..., axis=-1), since jax is missing it.

    Scatters `values` into `arr` at `indices` along the last axis without
    in-place mutation: builds a one-hot selector, projects values through it,
    and blends with the original array.
    """
    assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim)
    selector = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
    # Positions touched by at least one index (nonzero count means "overwrite").
    hit_count = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), selector)
    scattered = jnp.einsum("...i,...in->...n", values, selector)
    return jnp.where(hit_count, scattered, arr)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Pi0FASTConfig(_model.BaseModelConfig):
    """Configuration for the Pi0-FAST autoregressive (token-based) action model."""

    # Computation dtype for the transformer towers.
    dtype: str = "bfloat16"
    # Variant of the PaliGemma backbone.
    paligemma_variant: _gemma.Variant = "gemma_2b"

    # Set the model specific defaults.
    action_dim: int = 32
    action_horizon: int = 32
    max_token_len: int = 250

    # Tokenizer for the fast model.
    fast_model_tokenizer: Any | None = None
    # Keyword arguments for the fast model tokenizer.
    fast_model_tokenizer_kwargs: dict[str, Any] | None = None

    @property
    @override
    def model_type(self) -> _model.ModelType:
        """Model type tag for this config."""
        return _model.ModelType.PI0_FAST

    @override
    def create(self, rng: at.KeyArrayLike) -> "Pi0FAST":
        """Instantiates a Pi0FAST model from this config."""
        return Pi0FAST(self, rngs=nnx.Rngs(rng))

    @override
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
        """Returns shape/dtype specs for (observation, actions) expected by the model.

        Unlike Pi0Config, the observation also carries token-level AR and loss
        masks used by the autoregressive objective.
        """
        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

        with at.disable_typechecking():
            observation_spec = _model.Observation(
                images={
                    "base_0_rgb": image_spec,
                    "base_1_rgb": image_spec,
                    "wrist_0_rgb": image_spec,
                },
                image_masks={
                    "base_0_rgb": image_mask_spec,
                    "base_1_rgb": image_mask_spec,
                    "wrist_0_rgb": image_mask_spec,
                },
                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
                token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
                token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_),
            )
        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)

        return observation_spec, action_spec

    def get_freeze_filter(self) -> nnx.filterlib.Filter:
        """Returns the freeze filter based on the model config.

        For LoRA variants, freezes base LLM params while keeping LoRA params
        trainable; otherwise nothing is frozen.
        """
        if "lora" in self.paligemma_variant:
            return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
        return nnx.Nothing
|
||||
|
||||
|
||||
class Pi0FAST(_model.BaseModel):
    """Pi0-FAST: an autoregressive PaliGemma model that predicts discrete
    action tokens with a next-token cross-entropy loss and samples them with
    a KV-cached greedy/temperature decoding loop.
    """

    def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
        paligemma_config = _gemma.get_config(config.paligemma_variant)
        # TODO: rewrite gemma in NNX. For now, use bridge.
        llm = nnx_bridge.ToNNX(
            _gemma.Module(
                **paligemma_config,
                embed_dtype=config.dtype,
                cache_dtype=config.dtype,
            )
        )
        llm.lazy_init(rngs=rngs, method="init")
        img = nnx_bridge.ToNNX(
            _siglip.Module(
                num_classes=paligemma_config.width,
                variant="So400m/14",
                pool_type="none",
                scan=True,
                dtype_mm=config.dtype,
            )
        )
        # Initialize the image tower with a fake observation's first camera image.
        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
        self.PaliGemma = nnx.Dict(llm=llm, img=img)

    @at.typecheck
    def embed_inputs(
        self, obs: _model.Observation
    ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]:
        """Embeds images + tokenized prompt into one sequence.

        Returns (embeddings, input_mask, ar_mask); ar_mask here is per-example
        integer-valued (unlike Pi0's shared boolean list).
        """
        input_mask = []
        ar_mask = []
        token_embeddings = []
        # embed images
        for name in obs.images:
            image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False)

            token_embeddings.append(image_token_embeddings)
            input_mask.append(
                einops.repeat(
                    obs.image_masks[name],
                    "b -> b s",
                    s=image_token_embeddings.shape[1],
                )
            )
            # image tokens attend to each other --> AR mask = 0
            ar_mask.append(0 * input_mask[-1])

        # add tokenized inputs
        assert obs.tokenized_prompt is not None, "Tokenized prompt is required"
        assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required"
        assert obs.token_ar_mask is not None, "Token auto-regressive mask is required"
        tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True)
        token_embeddings.append(tokenized_inputs_embeddings)
        input_mask.append(obs.tokenized_prompt_mask)
        ar_mask.append(obs.token_ar_mask)

        # return embeddings, input mask, and ar mask
        return (
            jnp.concatenate(token_embeddings, axis=1),
            jnp.concatenate(input_mask, axis=1),
            jnp.concatenate(ar_mask, axis=1),
        )

    @override
    def compute_loss(
        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
    ) -> at.Float[at.Array, "*b ah"]:
        """Next-token cross-entropy over positions selected by token_loss_mask,
        normalized by the number of loss tokens per example.
        """
        observation = _model.preprocess_observation(
            rng, observation, train=train, image_keys=list(observation.images.keys())
        )

        # Compute inputs: one big forward pass of prefix + suffix at once
        input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
        attn_mask = make_attn_mask(input_mask, ar_mask)

        # Compute one-hot targets: we predict *next* token, so shift the input tokens by one.
        targets = jax.nn.one_hot(
            observation.tokenized_prompt[:, 1:],
            self.PaliGemma.llm.module.vocab_size,
        )

        # Each input predicts *next* token, so we don't input the last token.
        pre_logits, _, _ = self.PaliGemma.llm(
            embedded_prefix=input_token_embeddings[:, :-1],
            mask=attn_mask[:, :-1, :-1],
            return_prelogits=True,
        )

        # Only decode logits for the target tokens to save memory
        # (decoding matmul is large because it is a seq_len x vocab_size dense layer).
        logits, _ = self.PaliGemma.llm(
            pre_logits=pre_logits[:, -targets.shape[1] :],
        )
        logp = jax.nn.log_softmax(logits, axis=-1)

        # Compute CE loss on token targets
        assert observation.token_loss_mask is not None, "Token loss mask is required"
        loss_mask = observation.token_loss_mask[:, 1:]
        token_pplx = jnp.sum(targets * logp, axis=-1)
        # Clip the denominator at 1 to avoid division by zero for all-masked rows.
        return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        max_decoding_steps: int | at.Int[at.Array, ""] = 256,
        temperature: float = 0.0,
    ) -> _model.Actions:
        """Autoregressively decodes up to `max_decoding_steps` tokens.

        temperature == 0.0 gives greedy argmax decoding; > 0.0 samples from the
        tempered categorical. Returns the raw output token matrix (decoded into
        actions by a downstream tokenizer, not here).
        """
        # TODO: this is a hack to get the image keys.
        observation = _model.preprocess_observation(
            None, observation, train=False, image_keys=list(observation.images.keys())
        )

        # embed inputs
        prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)

        # left to right align all input token sequences
        prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
            prefix_token_embeddings, prefix_mask, prefix_attn_mask
        )
        prefill_size = prefix_token_embeddings.shape[1]
        prefill_len = jnp.sum(prefix_mask, axis=-1)
        # Index where each example's valid (right-aligned) prefix begins.
        prefix_start = prefill_size - prefill_len

        # first fill KV cache with a forward pass of the prefix
        # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
        prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
        prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
        prefix_logits, kv_cache, _ = self.PaliGemma.llm(
            embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
        )

        # prepare decoding -- final logit decodes the first token
        last_logit = prefix_logits[:, -1:]
        # NOTE(review): float-dtype buffer holding integer token ids; downstream
        # consumers appear to accept this -- confirm before changing the dtype.
        output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))

        def step(carry):
            rng, last_logit, output_tokens, cache, _, step = carry

            # Sample token from last logit
            # Split RNG for this step
            rng, rng_step = jax.random.split(rng)
            token = jax.lax.cond(
                temperature > 0.0,
                lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
                lambda _: jnp.argmax(last_logit, axis=-1),
                operand=None,
            )
            output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)

            # Check for early stopping --> stop if all batch elements have EOS token
            has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
            all_eos = jnp.all(has_eos)

            # Decode one step
            token_embedding = self.PaliGemma.llm(token, embed_only=True)
            positions = prefill_len[:, None] + step + 1
            # Attend to everything from the start of the valid prefix through the
            # token generated this step.
            mask = jnp.logical_and(
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
                jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
                < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
            )
            last_logit, kv_cache, _ = self.PaliGemma.llm(
                embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
            )

            return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1

        def cond(carry):
            _, _, _, _, all_eos, step = carry
            return (~all_eos) & (step < max_decoding_steps)

        # Use lax.while_loop so we can jit the full decoding loop.
        _, _, output_tokens, _, _, _ = jax.lax.while_loop(
            cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
        )
        return output_tokens
|
||||
46
policy/openpi-InternData-A1/src/openpi/models/pi0_test.py
Normal file
46
policy/openpi-InternData-A1/src/openpi/models/pi0_test.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import flax.nnx as nnx
|
||||
import jax
|
||||
|
||||
import openpi.models.pi0_config as _pi0_config
|
||||
|
||||
|
||||
def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:
    """Builds the model abstractly (shapes only) and returns the flat state of
    all Params selected by the config's freeze filter."""
    abstract_model = nnx.eval_shape(config.create, jax.random.key(0))
    frozen_filter = nnx.All(nnx.Param, config.get_freeze_filter())
    return nnx.state(abstract_model, frozen_filter).flat_state()
|
||||
|
||||
|
||||
def test_pi0_full_finetune():
    """Default config uses no LoRA, so no params should be frozen."""
    frozen = _get_frozen_state(_pi0_config.Pi0Config())
    assert len(frozen) == 0
|
||||
|
||||
|
||||
def test_pi0_gemma_lora():
    """PaliGemma LoRA freezes base LLM params, excluding LoRA and action-expert ("_1") params."""
    frozen = _get_frozen_state(_pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora"))
    assert len(frozen) == 9
    for path in frozen:
        assert "lora" not in path
        assert "llm" in path
        assert "_1" not in path
|
||||
|
||||
|
||||
def test_pi0_action_expert_lora():
    """Action-expert LoRA freezes base action-expert params only."""
    config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora")
    state = _get_frozen_state(config)
    # excluding embedder, rest of the params should be same as gemma_lora.
    assert len(state) == 8
    assert all("lora" not in p for p in state)
    assert all("llm" in p for p in state)
    # all frozen params should have _1 in their path since it's the action expert.
    # NOTE(review): this line iterates each path's components ("_1" in p per
    # element) unlike the sibling tests, which test the path as a whole --
    # confirm which iteration style flat_state() paths actually require.
    assert all(any("_1" in p for p in path) for path in state)
|
||||
|
||||
|
||||
def test_pi0_all_lora():
    """Both towers LoRA: frozen set is the union of the two single-LoRA cases."""
    config = _pi0_config.Pi0Config(
        paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
    )
    frozen = _get_frozen_state(config)
    # sum of gemma_lora and action_expert_lora's frozen params.
    assert len(frozen) == 17
    for path in frozen:
        assert "lora" not in path
        assert "llm" in path
||||
373
policy/openpi-InternData-A1/src/openpi/models/siglip.py
Normal file
373
policy/openpi-InternData-A1/src/openpi/models/siglip.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A refactored and simplified ViT adoptation for Pi, taken from big_vision."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
import openpi.training.sharding as sharding
|
||||
|
||||
|
||||
def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
    """Follows the MoCo v3 logic.

    Returns a (1, h*w, width) table of fixed 2D sine-cosine position embeddings,
    laid out as [sin(x), cos(x), sin(y), cos(y)] with width//4 frequencies each.
    """
    assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
    grid_y, grid_x = jnp.mgrid[:h, :w]
    n_freq = width // 4
    # Geometric frequency ladder from 1 down to 1/temperature.
    inv_freq = 1.0 / (temperature ** (jnp.arange(n_freq) / (n_freq - 1)))
    phase_y = grid_y.flatten()[:, None] * inv_freq[None, :]
    phase_x = grid_x.flatten()[:, None] * inv_freq[None, :]
    pe = jnp.concatenate([jnp.sin(phase_x), jnp.cos(phase_x), jnp.sin(phase_y), jnp.cos(phase_y)], axis=1)
    return jnp.asarray(pe, dtype)[None, :, :]
|
||||
|
||||
|
||||
def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
    """Returns a (1, prod(seqshape), width) positional-embedding table.

    typ == "learn" creates a trainable flax param on `self`; "sincos2d"
    returns the fixed MoCo-v3 encoding; anything else raises ValueError.
    """
    if typ == "sincos2d":
        return posemb_sincos_2d(*seqshape, width, dtype=dtype)
    if typ == "learn":
        init = nn.initializers.normal(stddev=1 / np.sqrt(width))
        return self.param(name, init, (1, np.prod(seqshape), width), dtype)
    raise ValueError(f"Unknown posemb type: {typ}")
|
||||
|
||||
|
||||
class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int | None = None  # Defaults to 4x input dim
    dropout: float = 0.0
    # Matmul dtype for the two Dense layers.
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        """Applies Transformer MlpBlock module.

        Dense(expand) -> GELU -> Dropout -> Dense(project back to input dim).
        `deterministic=True` disables dropout.
        """
        inits = {
            "kernel_init": nn.initializers.xavier_uniform(),
            "bias_init": nn.initializers.normal(stddev=1e-6),
        }

        _, _, d = x.shape  # n,l,d
        x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic)
        return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
|
||||
|
||||
|
||||
class Encoder1DBlock(nn.Module):
    """Single transformer encoder block (MHSA + MLP).

    Pre-LayerNorm residual block; returns both the output and a dict of
    intermediate activations for introspection.
    """

    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        # `out` collects named intermediates: "sa", "+sa", "mlp", "+mlp".
        out = {}
        x = sharding.activation_sharding_constraint(x)
        # Self-attention sub-block with pre-norm and residual add.
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["sa"] = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            deterministic=deterministic,
            dtype=self.dtype_mm,
        )(y, y)
        y = sharding.activation_sharding_constraint(y)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+sa"] = x + y

        # MLP sub-block with pre-norm and residual add.
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["mlp"] = MlpBlock(
            mlp_dim=self.mlp_dim,
            dropout=self.dropout,
            dtype_mm=self.dtype_mm,
        )(y, deterministic)
        y = sharding.activation_sharding_constraint(y)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+mlp"] = x + y
        x = sharding.activation_sharding_constraint(x)
        return x, out
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    # Number of stacked Encoder1DBlock layers.
    depth: int
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    # If True, a single remat'd block is scanned over `depth` (params stacked along axis 0).
    scan: bool = False
    # Attribute name looked up on jax.checkpoint_policies for rematerialization.
    remat_policy: str = "nothing_saveable"
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        """Runs the block stack; returns (encoder_norm(x), dict of per-block intermediates)."""
        out = {}

        if self.scan:
            # Rematerialize each block; `deterministic` (arg index 2 incl. self) must be static.
            block = nn.remat(
                Encoder1DBlock,
                prevent_cse=False,
                static_argnums=(2,),  # 0=self, 2=deterministic
                policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
            )
            x, scan_out = nn.scan(
                block,
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=nn.broadcast,
                length=self.depth,
            )(
                name="encoderblock",
                dtype_mm=self.dtype_mm,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout=self.dropout,
            )(x, deterministic)
            # Unstack the scanned per-layer outputs into the same keys the unrolled path uses.
            for lyr in range(self.depth):
                out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out)
        else:
            # Input Encoder
            for lyr in range(self.depth):
                block_cur = Encoder1DBlock(
                    name=f"encoderblock_{lyr}",
                    dtype_mm=self.dtype_mm,
                    mlp_dim=self.mlp_dim,
                    num_heads=self.num_heads,
                    dropout=self.dropout,
                )
                x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
        out["pre_ln"] = x  # Alias for last block, but without the number in it.

        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out
|
||||
|
||||
|
||||
class MAPHead(nn.Module):
    """Multihead Attention Pooling.

    Pools a (batch, length, dim) sequence into a (batch, dim) vector by
    cross-attending a single learned probe token to the sequence.
    """

    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x):
        n, _, d = x.shape  # n,l,d
        probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)
        probe = jnp.tile(probe, [n, 1, 1])

        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dtype=self.dtype_mm,
            kernel_init=nn.initializers.xavier_uniform(),
        )(probe, x)

        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        # BUG FIX: MlpBlock declares a `dtype_mm` attribute, not `dtype`; the previous
        # `dtype=` keyword raised a TypeError whenever this head was instantiated.
        x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype_mm=self.dtype_mm)(y)
        # The probe is the only query, so the pooled vector is at position 0.
        return x[:, 0]
|
||||
|
||||
|
||||
class _Module(nn.Module):
    """ViT model.

    Patchifies an image with a strided Conv, adds position embeddings, runs a
    transformer Encoder, pools per `pool_type`, and optionally applies a
    pre-logits projection and a classification head.
    """

    num_classes: int | None = None
    patch_size: Sequence[int] = (16, 16)
    width: int = 768
    depth: int = 12
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    posemb: str = "learn"  # Can also be "sincos2d"
    rep_size: int | bool = False
    dropout: float = 0.0
    pool_type: str = "gap"  # Can also be "map" or "tok"
    head_zeroinit: bool = True
    scan: bool = False
    # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
    remat_policy: str = "nothing_saveable"
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, image, *, train=False):
        """Returns (pooled logits/features, dict of named intermediate activations)."""
        out = {}

        # Kevin edit: do patch extraction and posemb in float32,
        # because I feel like it's a bit safer.
        image = jnp.asarray(image, jnp.float32)

        # Patch extraction
        x = out["stem"] = nn.Conv(
            self.width,
            self.patch_size,
            strides=self.patch_size,
            padding="VALID",
            name="embedding",
            dtype=jnp.float32,
        )(image)

        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # Add posemb before adding extra token.
        x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32)

        if self.pool_type == "tok":
            cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
            x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)

        n, _, c = x.shape  # n,l,d
        x = nn.Dropout(rate=self.dropout)(x, not train)

        # Kevin edit: now cast back to dtype_mm (potentially half precision)
        x = x.astype(self.dtype_mm)

        x, out["encoder"] = Encoder(
            depth=self.depth,
            mlp_dim=self.mlp_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
            scan=self.scan,
            remat_policy=self.remat_policy,
            dtype_mm=self.dtype_mm,
            name="Transformer",
        )(x, deterministic=not train)
        encoded = out["encoded"] = x

        if self.pool_type == "map":
            x = out["head_input"] = MAPHead(
                num_heads=self.num_heads,
                mlp_dim=self.mlp_dim,
                # BUG FIX: MAPHead declares `dtype_mm`, not `dtype`; the previous
                # `dtype=` keyword raised a TypeError when pool_type == "map".
                dtype_mm=self.dtype_mm,
            )(x)
        elif self.pool_type == "gap":
            x = out["head_input"] = jnp.mean(x, axis=1)
        elif self.pool_type == "0":
            x = out["head_input"] = x[:, 0]
        elif self.pool_type == "tok":
            x = out["head_input"] = x[:, 0]
            # Drop the class token so `encoded` reshapes back to the (h, w) grid.
            encoded = encoded[:, 1:]
        elif self.pool_type == "none":
            pass
        else:
            raise ValueError(f"Unknown pool type: '{self.pool_type}'")

        x_2d = jnp.reshape(encoded, [n, h, w, -1])

        if self.rep_size:
            rep_size = self.width if self.rep_size is True else self.rep_size
            hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits")
            # NOTE: In the past we did not include tanh in pre_logits.
            # For few-shot, it should not matter much, as it whitens anyways.
            x_2d = nn.tanh(hid(x_2d))
            x = nn.tanh(hid(x))

        out["pre_logits_2d"] = x_2d
        out["pre_logits"] = x

        if self.num_classes:
            kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
            head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw)
            x_2d = out["logits_2d"] = head(x_2d)
            x = out["logits"] = head(x)

        return x, out
|
||||
|
||||
|
||||
def Module(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name # noqa: N802
    """Builds a `_Module` ViT: variant-derived defaults, overridden by explicit kwargs."""
    config = {**decode_variant(variant), **kw}
    return _Module(num_classes, **config)
|
||||
|
||||
|
||||
def decode_variant(variant):
    """Converts a string like "B" or "B/32" into a params dict."""
    if variant is None:
        return {}

    v, patch = variant, {}
    if "/" in variant:
        v, patch_size = variant.split("/")
        patch = {"patch_size": (int(patch_size), int(patch_size))}

    # (width, depth, mlp_dim, num_heads) per variant.
    # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
    width, depth, mlp_dim, num_heads = {
        "mu": (32, 1, 128, 2),
        "Ti": (192, 12, 768, 3),
        "S": (384, 12, 1536, 6),
        "M": (512, 12, 2048, 8),
        "B": (768, 12, 3072, 12),
        "L": (1024, 24, 4096, 16),
        "So400m": (1152, 27, 4304, 16),
        "H": (1280, 32, 5120, 16),
        "g": (1408, 40, 6144, 16),
        "g-opt": (1536, 40, 6144, 16),
        "G": (1664, 48, 8192, 16),
        "G-opt": (1536, 48, 8192, 16),
        "e": (1792, 56, 15360, 16),
    }[v]

    return {
        "width": width,
        "depth": depth,
        "mlp_dim": mlp_dim,
        "num_heads": num_heads,
        **patch,
    }
|
||||
371
policy/openpi-InternData-A1/src/openpi/models/tokenizer.py
Normal file
371
policy/openpi-InternData-A1/src/openpi/models/tokenizer.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import sentencepiece
|
||||
from transformers import AutoProcessor
|
||||
|
||||
import openpi.models.utils.fsq_tokenizer as fsq_tokenizer
|
||||
import openpi.shared.download as download
|
||||
|
||||
|
||||
class PaligemmaTokenizer:
    """Text tokenizer based on the PaliGemma SentencePiece model.

    Produces fixed-length token sequences for pi0-style (text only) or
    pi05-style (text + discretized state) prompts.
    """

    def __init__(self, max_len: int = 48):
        # Fixed output length: shorter sequences are padded, longer ones truncated.
        self._max_len = max_len

        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

    def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
        """Returns (token_ids, validity_mask), both of length `max_len`."""
        cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ")
        if state is None:
            # This is the Pi0 format, where the state is part of the continuous action expert input.
            # tokenize "\n" separately as the "start of answer" token
            tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
        else:
            # This is the Pi05 format, where the state is part of the discrete language input.
            discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
            state_str = " ".join(map(str, discretized_state))
            full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
            tokens = self._tokenizer.encode(full_prompt, add_bos=True)

        n_tokens = len(tokens)
        if n_tokens >= self._max_len:
            if n_tokens > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            mask = [True] * self._max_len
        else:
            # Pad with False, which acts as token id 0 in the ids array.
            padding = [False] * (self._max_len - n_tokens)
            mask = [True] * n_tokens + padding
            tokens = tokens + padding

        return np.asarray(tokens), np.asarray(mask)
|
||||
|
||||
|
||||
class FASTTokenizer:
    """Tokenizer for FAST-style autoregressive action models.

    Uses the PaliGemma SentencePiece tokenizer for the text/state prefix and the
    pretrained FAST action tokenizer for the action postfix, remapping FAST token
    ids into the (non-special) top of the PaliGemma vocabulary.
    """

    def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"):
        # Fixed output sequence length: shorter sequences are padded, longer truncated.
        self._max_len = max_len

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        # Instantiate FAST tokenizer
        self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Builds one training/inference example.

        Args:
            prompt: Natural-language task description.
            state: State vector; assumed normalized to [-1, 1] (discretized into 256 bins).
            actions: Optional action chunk; when provided it is FAST-encoded as the postfix.

        Returns:
            (tokens, token_mask, ar_mask, loss_mask), each of length `max_len`.
        """
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
            action_tokens = self._fast_tokenizer(actions[None])[0]
            action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens)

            # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
            postfix_tokens = (
                self._paligemma_tokenizer.encode("Action: ")
                + action_tokens_in_pg.tolist()
                + self._paligemma_tokenizer.encode("|", add_eos=True)
            )
        else:
            postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            # False padding acts as id/flag 0 in all four arrays.
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        """Decodes model output tokens into an (action_horizon, action_dim) array.

        Returns zeros if the output contains no 'Action: ' segment.
        """
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        # The PaliGemma<->FAST id mapping is its own inverse, so this recovers FAST ids.
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        return self._fast_tokenizer.decode(
            [action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim
        )[0]

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        """Maps FAST action token ids into the top of the PaliGemma vocab (an involution)."""
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
|
||||
|
||||
|
||||
###########################################################################
|
||||
## The tokenizers below are used for RoboArena baseline implementations. ##
|
||||
## They are *not* used for pi0-style models. ##
|
||||
###########################################################################
|
||||
|
||||
|
||||
class BinningTokenizer:
    """
    Standard RT-2 / OpenVLA style binning tokenizer.
    """

    def __init__(self, max_len: int = 256, n_bins: int = 256):
        # Fixed output sequence length; number of discrete action bins per dimension.
        self._max_len = max_len
        self._n_bins = n_bins

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Tokenize a prompt and state into a sequence of tokens.

        Args:
            prompt: The text prompt to tokenize.
            state: The state array to discretize and tokenize.
            actions: Must be None. Action encoding is not currently supported.

        Returns:
            A tuple of (tokens, token_mask, ar_mask, targets).

        Raises:
            NotImplementedError: If actions is not None.
        """
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            # False padding acts as id/flag 0 in all four arrays.
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        """Decodes output tokens into an (action_horizon, action_dim) array in [-1, 1).

        Returns zeros if no 'Action: ' segment is present or too few tokens were produced.
        """
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        # The PaliGemma<->bin-id mapping is its own inverse, so this recovers bin indices.
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        if len(action_tokens) < action_horizon * action_dim:
            return np.zeros([action_horizon, action_dim], dtype=np.float32)
        action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
        # NOTE(review): bin index i maps back to the bin's left edge i/n_bins*2-1,
        # not the bin center — presumably matching the training-time convention; verify.
        return action_tokens / self._n_bins * 2 - 1

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        """Maps action bin indices into the top of the PaliGemma vocab (an involution)."""
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
|
||||
|
||||
|
||||
class FSQTokenizer:
    """
    FSQ tokenizer from the FAST paper baselines.
    """

    def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
        # Fixed output sequence length for tokenize().
        self._max_len = max_len

        assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
        # Download tokenizer
        path = download.maybe_download(fsq_tokenizer_path)
        # Assumes the downloaded directory contains exactly one checkpoint entry.
        tok_path = os.path.join(path, os.listdir(path)[0])

        # Split step from path
        # Assumes the last path component is the integer checkpoint step (orbax layout).
        step = int(tok_path.split("/")[-1])
        base_path = tok_path.rsplit("/", 1)[0]

        mgr = ocp.CheckpointManager(
            base_path,
            item_handlers={
                "params": ocp.StandardCheckpointHandler(),
                "opt_state": ocp.StandardCheckpointHandler(),
                "config": ocp.JsonCheckpointHandler(),
            },
            options=ocp.CheckpointManagerOptions(max_to_keep=1),
        )

        try:
            # Restore only config + params (opt_state is not needed for inference).
            restored = mgr.restore(
                step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
            )
            config = restored["config"]
            self._params = restored["params"]
            self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
            ) from e

        # Compile tokenize and detokenize functions
        self._tokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
        )
        self._detokenize_fn = jax.jit(
            lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
        )

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Builds (tokens, token_mask, ar_mask, loss_mask), each of length `max_len`.

        Raises:
            NotImplementedError: If `actions` is not None (inference-only tokenizer).
        """
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            # False padding acts as id/flag 0 in all four arrays.
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        """Decodes output tokens via the FSQ detokenizer into (action_horizon, action_dim).

        Returns zeros if no 'Action: ' segment is present or detokenization fails.
        """
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        # The PaliGemma<->FSQ id mapping is its own inverse, so this recovers FSQ ids.
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        try:
            # Move computation to CPU and compile on-demand
            device = jax.devices("cpu")[0]
            with jax.default_device(device):
                detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
                return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
        except Exception as e:
            logging.warning(f"Error decoding FSQ: {e}")
            return np.zeros((action_horizon, action_dim))

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        """Maps FSQ token ids into the top of the PaliGemma vocab (an involution)."""
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
|
||||
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
|
||||
from openpi.models import tokenizer as _tokenizer
|
||||
|
||||
|
||||
def test_tokenize():
    """Tokenized output and mask are always padded/truncated to `max_len`."""
    max_len = 10
    tok = _tokenizer.PaligemmaTokenizer(max_len=max_len)

    tokens, masks = tok.tokenize("Hello, world!")

    assert tokens.shape == (max_len,)
    assert masks.shape == (max_len,)
|
||||
|
||||
|
||||
def test_fast_tokenizer():
    """FAST tokenization round-trips shapes: fixed-length masks in, action chunk out."""
    max_len = 256
    horizon, action_dim = 3, 2
    prompt = "Hello, world!"
    state = np.random.rand(5).astype(np.float32)
    action = np.random.rand(horizon, action_dim).astype(np.float32)

    tokenizer = _tokenizer.FASTTokenizer(max_len=max_len)
    tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action)

    for arr in (tokens, token_masks, ar_masks, loss_masks):
        assert arr.shape == (max_len,)

    act = tokenizer.extract_actions(tokens, horizon, action_dim)
    assert act.shape == (horizon, action_dim)
|
||||
@@ -0,0 +1,472 @@
|
||||
import math
|
||||
from typing import Any, Literal
|
||||
|
||||
import chex
|
||||
from einops import einops
|
||||
from flax import linen as nn
|
||||
from flax.linen.module import Module
|
||||
from flax.linen.module import compact
|
||||
from flax.struct import dataclass
|
||||
from flax.typing import Array
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class FsqCodebook(nn.Module):
    """Finite Scalar Quantization (FSQ) codebook.

    Projects inputs down to a few dimensions, quantizes each dimension into a
    fixed number of bins via round(tanh(x)), and packs the per-dimension digits
    into a single integer token using mixed-radix encoding.
    """

    input_dim: int
    target_codebook_size: int
    # BUG FIX: "custom" was dispatched on in `bins_per_dim` but missing from the Literal.
    codebook_type: Literal["fsq", "lfq", "custom"]

    # Optional explicit override of the bins-per-dimension layout.
    _bins_per_dim: tuple[int, ...] | None = None

    @property
    def bins_per_dim(self) -> tuple[int, ...]:
        """Number of quantization bins for each projected latent dimension."""
        if self._bins_per_dim is not None:
            return self._bins_per_dim

        if self.codebook_type == "fsq":
            return self._get_bins_fsq(self.target_codebook_size)
        elif self.codebook_type == "lfq":  # noqa: RET505
            return self._get_bins_lfq(self.target_codebook_size)
        elif self.codebook_type == "custom":
            return self._get_bins_custom(self.target_codebook_size)
        else:
            raise ValueError(f"Codebook type {self.codebook_type} not supported.")

    @property
    def place_values(self) -> jnp.ndarray:
        """Mixed-radix place values: dimension i is weighted by prod(bins[:i])."""
        place_values = [1]
        for b in self.bins_per_dim[:-1]:
            place_values.append(place_values[-1] * b)
        return jnp.array(place_values)

    @staticmethod
    def _get_bins_fsq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension based on codebook size, from the original FSQ paper.
        """
        if target_codebook_size == 2**8:
            return (8, 6, 5)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (8, 5, 5, 5)
        elif target_codebook_size == 2**12:
            return (7, 5, 5, 5, 5)
        elif target_codebook_size == 2**14:
            return (8, 8, 8, 6, 5)
        elif target_codebook_size == 2**16:
            return (8, 8, 8, 5, 5, 5)
        else:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_custom(target_codebook_size: int) -> tuple[int, ...]:
        """Two-dimensional layout with sqrt(codebook_size) bins per dimension."""
        if target_codebook_size == 2**8:
            return (16, 16)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (32, 32)
        elif target_codebook_size == 2**12:
            return (64, 64)
        elif target_codebook_size == 2**14:
            return (128, 128)
        elif target_codebook_size == 2**16:
            return (256, 256)
        # BUG FIX: previously fell through to `return None`, which surfaced later as an
        # unrelated TypeError (e.g. len(None)); raise like the sibling helpers instead.
        raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_lfq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
        """
        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"

        return (2,) * int(math.log2(target_codebook_size))

    def setup(self):
        self.proj_down = nn.Dense(len(self.bins_per_dim))
        self.proj_up = nn.Dense(self.input_dim)

    def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Quantize and reconstruct; returns (tokens, reconstruction with STE gradients)."""
        tokens, z = self.encode(inputs)
        output = self.decode(tokens, z_grad=z)
        return tokens, output

    def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Returns (integer tokens, pre-quantization latents z in (-1, 1))."""
        bases = jnp.array(self.bins_per_dim)

        x = self.proj_down(inputs)
        z = jnp.tanh(x)

        # Quantize each dimension into its bin index, then pack into one integer.
        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
        tokens = self.undigitize(digits)

        return tokens, z

    def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:
        """Reconstructs latents from tokens; `z_grad` enables the straight-through estimator."""
        bases = jnp.array(self.bins_per_dim)
        digits = self.digitize(tokens)

        # Map bin indices back to bin centers in [-1, 1].
        z_q = digits / (bases - 1) * 2 - 1

        if z_grad is not None:
            chex.assert_equal_shape([z_q, z_grad])
            # Straight-through: forward uses z_q, gradients flow through z_grad.
            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad

        return self.proj_up(z_q)

    def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
        """Packs per-dimension digits into a single mixed-radix integer token."""
        return jnp.sum(digits * jnp.array(self.place_values), axis=-1)

    def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
        """Unpacks integer tokens back into per-dimension digits."""
        return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)

    @property
    def vocab_size(self) -> int:
        return math.prod(self.bins_per_dim)
|
||||
|
||||
|
||||
class ResNetDownBlock(nn.Module):
    """1-D residual downsampling block: Conv -> GroupNorm -> Dropout -> ReLU -> Conv, plus skip."""

    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        # Project the skip path whenever the temporal or channel shape changes.
        residual = x
        needs_projection = self.stride > 1 or x.shape[-1] != self.n_filters
        if needs_projection:
            residual = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(residual)

        h = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
        h = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(h)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=not train)
        h = nn.relu(h)
        h = nn.Conv(self.n_filters, (3,), (1,), "SAME")(h)

        return residual + h
|
||||
|
||||
|
||||
class ResNetUpBlock(nn.Module):
    """1-D residual upsampling block using transposed convolutions, plus skip."""

    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        # Upsample the skip path only when the block itself upsamples.
        residual = x
        if self.stride > 1:
            residual = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(residual)

        h = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
        h = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(h)
        h = nn.Dropout(self.dropout_rate)(h, deterministic=not train)
        h = nn.relu(h)
        h = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(h)

        return residual + h
|
||||
|
||||
|
||||
@dataclass
class LfqCodebookOutput:
    """Outputs of a LookupFreeQuantization training pass (see `loss`)."""

    tokens: jnp.ndarray  # integer token ids, one per quantized vector
    z: jnp.ndarray  # continuous (pre-quantization) latents, projected back up
    z_q: jnp.ndarray  # quantized latents (straight-through), projected back up
    token_log_probs: jnp.ndarray  # log-probs over the token vocabulary (currently a zeros placeholder)
    commit_loss: jnp.ndarray  # commitment (quantization) loss term
|
||||
|
||||
|
||||
class LookupFreeQuantization(nn.Module):
    """Lookup-free quantization (LFQ) with a fixed binary codebook {-1, +1}.

    Each of `num_dims` latent channels is independently quantized to the
    nearest of {-1, +1}; the resulting bit vector is packed (little-endian)
    into one integer token, giving a vocabulary of 2 ** num_dims.

    Attributes:
        num_dims: number of binary latent dimensions (bits per token).
        latent_dim: dimensionality of the un-projected latent space.
    """

    num_dims: int
    latent_dim: int

    def setup(self):
        self.codebook = jnp.array([-1, 1])
        self.activation = nn.tanh
        self.project_down = nn.Dense(self.num_dims)
        self.project_up = nn.Dense(self.latent_dim)

    def encode(self, z: jnp.ndarray) -> jnp.ndarray:
        """Quantize latents to integer tokens (no straight-through, no loss)."""
        z = self.project_down(z)
        # Nearest codebook entry per dimension: index 0 -> -1, index 1 -> +1.
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        token_bits = jnp.argmin(token_squared_distances, axis=-1)
        # Pack the bits little-endian into a single integer token id.
        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)

    def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
        """Map integer token ids back to latent vectors."""
        # Bug fix: extract each bit as 0/1. The previous expression
        # `tokens & 2**k` yields 0 or 2**k, which only produced correct results
        # because JAX clamps out-of-range indices to the last codebook entry;
        # `(tokens >> k) & 1` is correct regardless of indexing semantics.
        token_bits = ((tokens[..., None] >> jnp.arange(self.num_dims)) & 1).astype(jnp.int32)
        return self.project_up(self.codebook[token_bits])

    def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
        """Quantize with a straight-through estimator and compute training terms."""
        z = self.project_down(x)
        z = self.activation(z)

        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        tokens = jnp.argmin(token_squared_distances, axis=-1)

        # Per-bit log-probs are the negated squared distances.
        token_bit_log_probs = -token_squared_distances
        # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs.
        # Bug fix: reduce the bitwise-and to a {0, 1} indicator matrix; previously the
        # entries were 0 or 2**k, which mis-weighted the per-bit log-probs. (The result
        # is currently discarded below, so observable behavior is unchanged.)
        token_bit_expansions = (
            jnp.bitwise_and(
                jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
            )
            != 0
        ).astype(jnp.int32)  # (num_dims, 2**num_dims), entries in {0, 1}
        token_log_probs = (
            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
            + token_bit_log_probs[..., 1] @ token_bit_expansions
        )  # (batch_size, num_tokens, 2 ** num_dims)
        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))

        z_q = self.codebook[tokens]
        commit_loss = jnp.square(z - z_q).mean()
        # Straight-through estimator: forward pass uses z_q, gradients flow to z.
        z_q = jax.lax.stop_gradient(z_q - z) + z

        z_q = self.project_up(z_q)
        z = self.project_up(z)

        # Pack per-dimension bit indices into integer token ids (base = codebook size).
        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
        return LfqCodebookOutput(
            tokens=tokens,
            z=z,
            z_q=z_q,
            # NOTE(review): the computed token_log_probs are intentionally not
            # returned (zeros placeholder in the original) — confirm downstream
            # code does not expect the real values.
            token_log_probs=jnp.zeros(()),
            commit_loss=commit_loss,
        )
|
||||
|
||||
|
||||
def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:
    """Build a block-causal attention mask over query/key index sequences.

    Query position i may attend key position j when i's block index is >= j's.
    NOTE(review): the pairwise comparison divides the query index by `bs_k` and
    the key index by `bs_q` — verify this cross-use of the block sizes is
    intentional (it matches how CrossAttentionLayer computes bs_q/bs_k) and not
    a swap of the divisors.
    """
    return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
|
||||
|
||||
|
||||
class GeGLU(Module):
    """Gated Linear Unit with a GELU gate (GeGLU).

    Projects the input to twice the output width, splits the projection into a
    value half and a gate half, and returns `value * gelu(gate)`. Commonly used
    in Transformer feed-forward blocks to combine a strong linear component
    with a non-linear gate.

    Attributes:
        output_dim: number of output features; -1 means "same as the input".
    """

    output_dim: int = -1

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Apply the GeGLU transformation to `inputs`.

        Args:
            inputs: nd-array whose last axis is the feature dimension.

        Returns:
            The gated transformation, with `output_dim` features on the last axis.
        """
        if self.output_dim == -1:
            width = inputs.shape[-1]
        else:
            width = self.output_dim

        projected = nn.Dense(width * 2)(inputs)
        value, gate = projected[..., :width], projected[..., width:]
        return value * nn.gelu(gate)
|
||||
|
||||
|
||||
class CrossAttentionLayer(nn.Module):
    """Pre-LayerNorm transformer layer: self-attention, cross-attention to `y`,
    then a GeGLU MLP — each sub-block wrapped in a residual connection.

    Attributes:
        dropout_rate: dropout used inside attention and the MLP.
        num_heads: number of attention heads; when None, defaults to d_embed // 64.
        causal: if True, (block-)causal masks are built internally, overriding
            any masks passed by the caller.
        mlp_ratio: hidden width of the MLP relative to the embedding dimension.
    """

    dropout_rate: float = 0.0
    num_heads: int | None = None
    causal: bool = False
    mlp_ratio: float = 4.0

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        y: jnp.ndarray,
        *,
        mask_self: jnp.ndarray | None = None,
        mask_cross: jnp.ndarray | None = None,
        train: bool = True,
    ) -> jnp.ndarray:
        """Run the layer.

        Args:
            x: query-side sequence, shape (..., seq_len_q, d_embed).
            y: key/value sequence for cross-attention, shape (..., seq_len_k, d_embed).
            mask_self: optional self-attention mask (overridden when `causal`).
            mask_cross: optional cross-attention mask (overridden when `causal`).
            train: enables dropout when True.

        Returns:
            Updated query-side sequence with the same shape as `x`.
        """
        d_embed = x.shape[-1]
        seq_len_q = x.shape[-2]
        seq_len_k = y.shape[-2]

        if self.causal:
            # One block size will be 1
            bs_q = max(seq_len_q // seq_len_k, 1)
            bs_k = max(seq_len_k // seq_len_q, 1)

            mask_self = nn.make_causal_mask(x[..., 0])
            mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)

        # Self-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, x, x, mask=mask_self)
        x = skip + x

        # Cross-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, y, y, mask=mask_cross)
        x = skip + x

        # MLP block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = GeGLU()(x)
        x = nn.Dense(d_embed)(x)
        return skip + x
|
||||
|
||||
|
||||
def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
    """Flax-style initializer producing a fixed sinusoidal positional-encoding table.

    The first half of the embedding channels hold sines and the second half
    cosines, with frequencies spaced geometrically (the scheme from
    "Attention Is All You Need").

    Args:
        _: unused PRNG key (the table is deterministic).
        shape: (seq_len, d_embed) of the table to build.

    Returns:
        A (seq_len, d_embed) array of positional encodings.
    """
    seq_len, d_embed = shape

    positions = jnp.arange(0, seq_len, 1)
    frequencies = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
    angles = positions[:, jnp.newaxis] * frequencies

    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
|
||||
|
||||
|
||||
class TokenizerEncoderDecoder(nn.Module):
    """Cross-attention stack that maps a sequence `y` onto `num_tokens` learned queries.

    Used both as the encoder (data -> latent tokens) and the decoder
    (latent tokens -> data horizon) of the attention tokenizer.

    Attributes:
        num_tokens: number of query tokens produced.
        num_cross_tokens: expected length of the cross-attended sequence `y`.
        num_layers: number of stacked CrossAttentionLayer blocks.
        causal: passed through to each CrossAttentionLayer.
        mlp_ratio: MLP expansion ratio inside each layer.
        use_state_conditioning: if True, append a projected state embedding to `y`.
    """

    num_tokens: int
    num_cross_tokens: int
    num_layers: int
    causal: bool

    mlp_ratio: float = 4.0
    use_state_conditioning: bool = False

    @nn.compact
    def __call__(
        self,
        y: jnp.ndarray,
        *,
        train: bool = True,
        state_conditioning: jnp.ndarray | None = None,
        mask: jnp.ndarray | None = None,
    ) -> jnp.ndarray:
        # Query tokens: a learned parameter initialized from a sinusoidal table,
        # broadcast across the batch dimensions of y.
        x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
        x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])

        if mask is not None:
            # mask is (batch_dims..., num_cross_tokens)
            chex.assert_equal_shape([y[..., 0], mask])
            # Expand to an attention mask with a broadcast head axis and the
            # query axis repeated for every query token.
            attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
        else:
            attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens))

        if self.use_state_conditioning:
            assert state_conditioning is not None, "State conditioning is required for this model."
            # Append the projected state as one extra, always-visible kv token.
            state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
            y = jnp.concatenate([y, state_embed], axis=-2)
            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)

        # Fixed-shape learned positional encoding on the kv sequence.
        y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])

        for _ in range(self.num_layers):
            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
                x, y, train=train, mask_self=None, mask_cross=attn_mask
            )

        return x
|
||||
|
||||
|
||||
class FsqAttentionTokenizer(nn.Module):
    """Attention-based action tokenizer with an FSQ codebook.

    Encodes an action chunk of shape (..., data_horizon, data_dim) into
    `num_tokens` discrete tokens via an FsqCodebook, and decodes tokens back
    into continuous actions. Both directions can optionally be conditioned on
    a state vector.

    Attributes:
        embed_dim: width of the internal embedding space.
        data_dim: per-step action dimensionality.
        data_horizon: number of action steps in a chunk.
        num_tokens: number of discrete tokens per chunk.
        num_layers: depth of the encoder and decoder attention stacks.
        target_codebook_size: requested codebook size (actual size is derived
            by FsqCodebook from its FSQ bin factorization).
        causal: use (block-)causal attention in encoder/decoder.
        mlp_ratio: MLP expansion ratio inside attention layers.
        bound: optional symmetric clipping bound applied before encoding.
        use_state_conditioning: condition encoder/decoder on `obs`.
    """

    embed_dim: int
    data_dim: int
    data_horizon: int
    num_tokens: int
    num_layers: int
    target_codebook_size: int
    causal: bool = False
    mlp_ratio: float = 2.0

    bound: float | None = None

    use_state_conditioning: bool = False

    @property
    def vocab_size(self) -> int:
        # Mirrors the codebook's effective vocabulary without instantiating it.
        return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))  # noqa: SLF001

    def setup(self):
        self.proj = nn.Dense(self.embed_dim)
        self.encoder = TokenizerEncoderDecoder(
            num_tokens=self.num_tokens,
            num_cross_tokens=self.data_horizon,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.codebook = FsqCodebook(
            input_dim=self.embed_dim,
            target_codebook_size=self.target_codebook_size,
            codebook_type="custom",
        )
        self.decoder = TokenizerEncoderDecoder(
            num_tokens=self.data_horizon,
            num_cross_tokens=self.num_tokens,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )

        self.proj_mean = nn.Dense(self.data_dim)
        # Learned scalar scale applied to the decoded mean.
        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))

    def tokenize(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Encode (optionally clipped) actions into discrete tokens via the codebook."""
        if self.bound is not None:
            action = jnp.clip(action, -self.bound, self.bound)

        x = self.proj(action)
        x = self.encoder(x, train=train, state_conditioning=obs)

        return self.codebook.encode(x)

    def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:
        """Decode discrete tokens back into a continuous action chunk."""
        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
        mean = self.proj_mean(x)
        return mean * self.out_scale

    def loss(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """Autoencoding reconstruction loss (MSE), with MAE reported for monitoring."""
        # Encode
        x = self.proj(action)
        z = self.encoder(x, train=train, state_conditioning=obs)

        # Quantize
        tokens, z = self.codebook(z)

        # Decode
        x = self.decoder(z, train=train, state_conditioning=obs)
        mean = self.proj_mean(x) * self.out_scale

        mse = jnp.mean(jnp.square(action - mean))
        mae = jnp.mean(jnp.abs(action - mean))

        return mse, {
            "mse": mse,
            "mae": mae,
        }

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """
        Dummy for .init
        """
        return self.loss(*args, **kwargs)
|
||||
307
policy/openpi-InternData-A1/src/openpi/models/vit.py
Normal file
307
policy/openpi-InternData-A1/src/openpi/models/vit.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright 2024 Google LLC.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from openpi.models import resnet as models_resnet
|
||||
|
||||
Array = Any
|
||||
PRNGKey = Any
|
||||
Shape = tuple[int]
|
||||
Dtype = Any
|
||||
|
||||
|
||||
class IdentityLayer(nn.Module):
    """Identity layer, convenient for giving a name to an array."""

    @nn.compact
    def __call__(self, x):
        # Pass-through: exists only so the array gets a named scope in the tree.
        return x
|
||||
|
||||
|
||||
class AddPositionEmbs(nn.Module):
    """Adds learned positional embeddings to the inputs.

    Attributes:
      posemb_init: positional embedding initializer.
      param_dtype: dtype of the position-embedding parameter.
    """

    posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Applies the AddPositionEmbs module.

        Args:
          inputs: Inputs to the layer.

        Returns:
          Output tensor with shape `(bs, timesteps, in_dim)`.
        """
        # inputs.shape is (batch_size, seq_len, emb_dim).
        assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
        # One table per (seq position, channel), broadcast over the batch axis.
        pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
        pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype)
        return inputs + pe
|
||||
|
||||
|
||||
class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    Dense(mlp_dim) -> GELU -> dropout -> Dense(out_dim) -> dropout.

    Attributes:
      mlp_dim: hidden width of the block.
      dtype: computation dtype.
      param_dtype: parameter dtype.
      out_dim: output width; defaults to the input width when None.
      dropout_rate: dropout applied after each dense layer.
      kernel_init: dense kernel initializer.
      bias_init: dense bias initializer.
    """

    mlp_dim: int
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    out_dim: int | None = None
    dropout_rate: float = 0.1
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """Applies Transformer MlpBlock module."""
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(
            features=self.mlp_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(  # pytype: disable=wrong-arg-types
            inputs
        )
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(
            features=actual_out_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(  # pytype: disable=wrong-arg-types
            x
        )
        return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
|
||||
|
||||
|
||||
class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
      inputs: input data.
      mlp_dim: dimension of the mlp on top of attention block.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
      deterministic: bool, deterministic or not (to apply dropout).
      num_heads: Number of heads in nn.MultiHeadDotProductAttention
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, deterministic):
        """Applies Encoder1DBlock module.

        Args:
          inputs: Inputs to the layer.
          deterministic: Dropout will not be applied when set to true.

        Returns:
          A (output, None) pair; the trailing None is the per-step carry
          expected by the nn.scan wrapper in Encoder.
        """

        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
            # why isn't this true by default???
            force_fp32_for_softmax=True,
        )(x, x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
            y, deterministic=deterministic
        )

        return x + y, None
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Attributes:
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: Number of heads in nn.MultiHeadDotProductAttention
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate in self attention.
      add_position_embedding: whether to prepend learned position embeddings.
    """

    dtype: jax.typing.DTypeLike
    num_layers: int
    mlp_dim: int
    num_heads: int
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    add_position_embedding: bool = True

    @nn.compact
    def __call__(self, x, *, train):
        """Applies Transformer model on the inputs.

        Args:
          x: Inputs to the layer.
          train: Set to `True` when training.

        Returns:
          output of a transformer encoder.
        """
        assert x.ndim == 3  # (batch, len, emb)

        if self.add_position_embedding:
            x = AddPositionEmbs(
                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
                name="posembed_input",
            )(x)
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        x = x.astype(self.dtype)
        # Input Encoder
        # Rematerialize (checkpoint) each block and scan it over the layer axis;
        # the `deterministic` arg (position 2) must be static for remat.
        block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
        x, _ = nn.scan(
            block,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=nn.broadcast,
            length=self.num_layers,
        )(
            name="encoderblock",
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            num_heads=self.num_heads,
        )(x, not train)
        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
    """VisionTransformer.

    Patch-embeds the input image (optionally after a partial ResNet root),
    runs a Transformer encoder, pools according to `classifier`, and applies
    an optional representation layer and classification head.
    """

    dtype: jax.typing.DTypeLike
    num_classes: int
    patches: Any  # expects a `.size` patch shape; TODO confirm exact config type
    transformer: Any  # kwargs forwarded to the encoder, or None to skip it
    hidden_size: int
    resnet: Any | None = None
    representation_size: int | None = None
    classifier: str = "token"
    head_bias_init: float = 0.0
    encoder: type[nn.Module] = Encoder
    model_name: str | None = None

    @nn.compact
    def __call__(self, inputs, *, train):
        """Run the ViT on a batch of images; returns logits or pooled features."""
        x = inputs
        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(
                features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
            )(x)
            x = nn.GroupNorm(name="gn_root")(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

            # ResNet stages.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
                )(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(
                        block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
                    )(x)

        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(
            features=self.hidden_size,
            kernel_size=self.patches.size,
            strides=self.patches.size,
            padding="VALID",
            name="embedding",
        )(x)

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if self.transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if self.classifier in ["token", "token_unpooled"]:
                cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)

        if self.classifier == "token":
            # Use only the class token's final embedding.
            x = x[:, 0]
        elif self.classifier == "gap":
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        elif self.classifier in ["unpooled", "token_unpooled"]:
            pass
        else:
            raise ValueError(f"Invalid classifier={self.classifier}")

        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name="pre_logits")(x)

        if self.num_classes:
            # Head kernel is zero-initialized so the model starts as a constant predictor.
            x = nn.Dense(
                features=self.num_classes,
                name="head",
                kernel_init=nn.initializers.zeros,
                bias_init=nn.initializers.constant(self.head_bias_init),
            )(x)
        return x
|
||||
@@ -0,0 +1,281 @@
|
||||
from typing import Literal
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GemmaForCausalLM
|
||||
from transformers import PaliGemmaForConditionalGeneration
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(nn.Module):
    """PaliGemma VLM paired with a smaller Gemma "action expert".

    Both transformers are built from HF configs derived from the given
    width/depth specs; `forward` runs them jointly so prefix (vision+language)
    and suffix (action) token streams can attend to each other.
    """

    def __init__(
        self,
        vlm_config,
        action_expert_config,
        use_adarms=None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
    ):
        # use_adarms is a pair of flags [vlm, action_expert] enabling adaptive
        # RMSNorm conditioning in the respective model.
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()

        # Build the PaliGemma config from the project's width/depth spec.
        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
        vlm_config_hf.image_token_index = 257152
        vlm_config_hf.text_config.hidden_size = vlm_config.width
        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
        vlm_config_hf.text_config.head_dim = vlm_config.head_dim
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.torch_dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.torch_dtype = "float32"

        # The action expert is a plain Gemma LM with its own (smaller) spec.
        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,
            hidden_size=action_expert_config.width,
            intermediate_size=action_expert_config.mlp_dim,
            num_attention_heads=action_expert_config.num_heads,
            num_hidden_layers=action_expert_config.depth,
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )

        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
        # The expert consumes embeddings directly; drop its token embedding table.
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
    def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
        """Cast parameters to the requested precision.

        For "bfloat16", everything is cast to bfloat16 and a small set of
        numerically sensitive parameters (vision embeddings and all norm
        layers) is then restored to float32. For "float32", the whole model is
        cast and the method returns early.
        """
        if precision == "bfloat16":
            self.to(dtype=torch.bfloat16)
        elif precision == "float32":
            self.to(dtype=torch.float32)
            return
        else:
            raise ValueError(f"Invalid precision: {precision}")

        # Substring selectors: any parameter whose name contains one of these
        # stays in float32.
        params_to_keep_float32 = [
            "vision_tower.vision_model.embeddings.patch_embedding.weight",
            "vision_tower.vision_model.embeddings.patch_embedding.bias",
            "vision_tower.vision_model.embeddings.position_embedding.weight",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",
        ]

        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_keep_float32):
                param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
    def embed_image(self, image: torch.Tensor):
        """Return the VLM's image features for `image` (via PaliGemma's vision tower)."""
        return self.paligemma.model.get_image_features(image)
|
||||
|
||||
    def embed_language_tokens(self, tokens: torch.Tensor):
        """Embed language token ids with the VLM's token embedding table."""
        return self.paligemma.language_model.embed_tokens(tokens)
|
||||
|
||||
    def forward(
        self,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        # NOTE(review): `pytest.Cache` in this annotation is almost certainly a
        # mis-import of transformers' `Cache` type — confirm and fix upstream.
        past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None,
        inputs_embeds: list[torch.FloatTensor] | None = None,
        use_cache: bool | None = None,
        adarms_cond: list[torch.Tensor] | None = None,
    ):
        """Run the VLM and/or the action expert over [prefix, suffix] embeddings.

        `inputs_embeds` is a pair [prefix_embeds, suffix_embeds]; exactly one
        may be None:
          - suffix None: prefill the VLM only (returns its KV cache);
          - prefix None: run only the action expert;
          - both present: run the two towers jointly, sharing attention per
            layer so the streams attend to each other.

        Returns:
            ([prefix_output, suffix_output], prefix_past_key_values); entries
            that were not computed are None.
        """
        if adarms_cond is None:
            adarms_cond = [None, None]
        if inputs_embeds[1] is None:
            # Prefix-only path: standard VLM text-tower forward, keeping the
            # KV cache for later suffix decoding.
            prefix_output = self.paligemma.language_model.forward(
                inputs_embeds=inputs_embeds[0],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
            )
            prefix_past_key_values = prefix_output.past_key_values
            prefix_output = prefix_output.last_hidden_state
            suffix_output = None
        elif inputs_embeds[0] is None:
            # Suffix-only path: run just the action expert.
            suffix_output = self.gemma_expert.model.forward(
                inputs_embeds=inputs_embeds[1],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
            )
            suffix_output = suffix_output.last_hidden_state
            prefix_output = None
            prefix_past_key_values = None
        else:
            # Joint path: interleave the two towers layer by layer with a
            # shared attention computation over the concatenated sequence.
            models = [self.paligemma.language_model, self.gemma_expert.model]
            num_layers = self.paligemma.config.text_config.num_hidden_layers

            # Check if gradient checkpointing is enabled for any of the models
            use_gradient_checkpointing = (
                hasattr(self.gemma_expert.model, "gradient_checkpointing")
                and self.gemma_expert.model.gradient_checkpointing
                and self.training
            ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)

            # Force enable gradient checkpointing if we're in training mode and the model supports it
            if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
                if not self.gemma_expert.model.gradient_checkpointing:
                    print("Forcing gradient checkpointing to be enabled for Gemma expert model")
                    self.gemma_expert.model.gradient_checkpointing = True
                    use_gradient_checkpointing = True

            # Debug gradient checkpointing status
            if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
                print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
                print(f"Model training mode: {self.training}")
                print(
                    f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
                )
                if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
                    print(
                        f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
                    )
                self._debug_gc_printed = True

            # Define the complete layer computation function for gradient checkpointing
            def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
                """One joint transformer layer over [prefix, suffix] streams."""
                models = [self.paligemma.language_model, self.gemma_expert.model]

                # Per-stream Q/K/V from each tower's own projections.
                query_states = []
                key_states = []
                value_states = []
                gates = []
                for i, hidden_states in enumerate(inputs_embeds):
                    layer = models[i].layers[layer_idx]
                    hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
                    gates.append(gate)

                    input_shape = hidden_states.shape[:-1]
                    hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
                    query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                    key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                    value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

                    query_states.append(query_state)
                    key_states.append(key_state)
                    value_states.append(value_state)

                # Concatenate and process attention
                # (dim=2 is the sequence axis after the head transpose).
                query_states = torch.cat(query_states, dim=2)
                key_states = torch.cat(key_states, dim=2)
                value_states = torch.cat(value_states, dim=2)

                # rotary_emb only uses the tensor for device/dtype; the zeros
                # tensor stands in for real hidden states.
                dummy_tensor = torch.zeros(
                    query_states.shape[0],
                    query_states.shape[2],
                    query_states.shape[-1],
                    device=query_states.device,
                    dtype=query_states.dtype,
                )
                cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
                query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, unsqueeze_dim=1
                )

                batch_size = query_states.shape[0]
                scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling

                # Attention computation
                att_output, _ = modeling_gemma.eager_attention_forward(
                    self.paligemma.language_model.layers[layer_idx].self_attn,
                    query_states,
                    key_states,
                    value_states,
                    attention_mask,
                    scaling,
                )
                # Get head_dim from the current layer, not from the model
                # NOTE(review): the `1 * 8` factor hard-codes 8 attention
                # heads — confirm against the configs used here.
                head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
                att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)

                # Process layer outputs: slice the joint attention output back
                # into per-stream segments and finish each tower's layer.
                outputs_embeds = []
                start_pos = 0
                for i, hidden_states in enumerate(inputs_embeds):
                    layer = models[i].layers[layer_idx]
                    end_pos = start_pos + hidden_states.shape[1]

                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                    out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])

                    # first residual
                    out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
                    after_first_residual = out_emb.clone()
                    out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
                    # Convert to bfloat16 if the next layer (mlp) uses bfloat16
                    if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
                        out_emb = out_emb.to(dtype=torch.bfloat16)

                    out_emb = layer.mlp(out_emb)
                    # second residual
                    out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
                    outputs_embeds.append(out_emb)
                    start_pos = end_pos

                return outputs_embeds

            # Process all layers with gradient checkpointing if enabled
            for layer_idx in range(num_layers):
                if use_gradient_checkpointing:
                    inputs_embeds = torch.utils.checkpoint.checkpoint(
                        compute_layer_complete,
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        use_reentrant=False,
                        preserve_rng_state=False,
                    )
                else:
                    inputs_embeds = compute_layer_complete(
                        layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
                    )

            # Old code removed - now using compute_layer_complete function above

            # final norm
            # Define final norm computation function for gradient checkpointing
            def compute_final_norms(inputs_embeds, adarms_cond):
                """Apply each tower's final RMSNorm to its own stream."""
                outputs_embeds = []
                for i, hidden_states in enumerate(inputs_embeds):
                    out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                    outputs_embeds.append(out_emb)
                return outputs_embeds

            # Apply gradient checkpointing to final norm if enabled
            if use_gradient_checkpointing:
                outputs_embeds = torch.utils.checkpoint.checkpoint(
                    compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False
                )
            else:
                outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)

            prefix_output = outputs_embeds[0]
            suffix_output = outputs_embeds[1]
            prefix_past_key_values = None

        return [prefix_output, suffix_output], prefix_past_key_values
|
||||
@@ -0,0 +1,461 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
import openpi.models.gemma as _gemma
|
||||
from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
||||
import openpi.models_pytorch.preprocessing_pytorch as _preprocessing
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
    """Map *target_dtype* to a dtype that is safe on the given device type.

    On CPU, ``torch.bfloat16`` is downgraded to ``torch.float32``; every
    other dtype (on any device) is returned unchanged.
    """
    if device_type != "cpu":
        return target_dtype
    # bfloat16 kernels may be unavailable on CPU, so fall back to float32.
    if target_dtype == torch.bfloat16:
        return torch.float32
    return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Compute sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: 1-D tensor of shape ``(batch_size,)`` holding scalar positions.
        dimension: Output embedding size; must be even (half sine, half cosine).
        min_period: Period of the fastest-varying frequency component.
        max_period: Period of the slowest-varying frequency component.
        device: ``torch.device`` or device string on which to build the embedding.

    Returns:
        Tensor of shape ``(batch_size, dimension)`` — sine features in the first
        half of the last axis, cosine features in the second half.

    Raises:
        ValueError: If ``dimension`` is odd or ``time`` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # BUGFIX: accept both strings and torch.device instances. The previous code
    # dereferenced `device.type` directly, which raised AttributeError whenever
    # the default string argument "cpu" (or any other string) was used.
    device = torch.device(device)

    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    # Geometric progression of periods from min_period up to max_period.
    period = min_period * (max_period / min_period) ** fraction

    # Outer product: one row of phases per batch element.
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
    """Draw ``bsize`` samples from a Beta(alpha, beta) distribution on *device*."""
    concentrations = [
        torch.as_tensor(value, dtype=torch.float32, device=device)
        for value in (alpha, beta)
    ]
    return torch.distributions.Beta(*concentrations).sample((bsize,))
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
    """Build boolean 2-D attention masks from padding and block-causality masks.

    Adapted from big_vision. A query token may attend to a key token when the
    key's cumulative ``att_masks`` value is less than or equal to the query's,
    and both tokens are valid (non-padding). ``att_masks`` entries are 1 where
    previous tokens must not depend on the token, and 0 where the token shares
    the attention pattern of its predecessor. This supports, for example:

      [[1 1 1 1 1 1]]            pure causal attention,
      [[0 0 0 1 1 1]]            prefix-lm attention (bidirectional prefix of 3),
      [[1 0 1 0 1 0 0 1 0 0]]    block-causal attention over 4 blocks.

    Args:
        pad_masks: bool[B, N], true for real input tokens, false for padding.
        att_masks: int32[B, N] block-causality indicator as described above.

    Returns:
        bool[B, N, N] mask, true where query token i may attend to key token j.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    boundaries = att_masks.cumsum(dim=1)
    causal_2d = boundaries[:, None, :] <= boundaries[:, :, None]
    valid_2d = pad_masks[:, None, :] * pad_masks[:, :, None]
    return causal_2d & valid_2d
|
||||
|
||||
|
||||
class PI0Pytorch(nn.Module):
    """PI0 flow-matching action policy in PyTorch.

    Combines a PaliGemma vision-language backbone with a smaller Gemma "action
    expert". Training (`forward`) regresses a flow-matching velocity field on
    noised actions; inference (`sample_actions`) integrates that field with
    explicit Euler steps starting from Gaussian noise, reusing the prefix
    (image + language) KV cache across denoising steps.
    """

    def __init__(self, config):
        """Build projections and the PaliGemma+expert backbone from *config*.

        Args:
            config: Policy config providing `pi05`, `paligemma_variant`,
                `action_expert_variant`, `dtype`, `action_horizon`, `action_dim`.
        """
        super().__init__()
        self.config = config
        # pi05 selects the adaRMS-conditioned variant (time enters via adaRMS
        # conditioning instead of being concatenated with the action tokens).
        self.pi05 = config.pi05

        paligemma_config = _gemma.get_config(config.paligemma_variant)
        action_expert_config = _gemma.get_config(config.action_expert_variant)

        # adaRMS is only enabled on the action expert (second model), and only
        # for the pi05 variant.
        self.paligemma_with_expert = PaliGemmaWithExpertModel(
            paligemma_config,
            action_expert_config,
            use_adarms=[False, True] if self.pi05 else [False, False],
            precision=config.dtype,
        )

        # Actions are projected from/to a fixed 32-dim space.
        # NOTE(review): presumably 32 is the padded max action dimension — confirm.
        self.action_in_proj = nn.Linear(32, action_expert_config.width)
        self.action_out_proj = nn.Linear(action_expert_config.width, 32)

        if self.pi05:
            # Time-conditioning MLP feeding the adaRMS condition vector.
            self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
            self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
        else:
            # Non-pi05: state becomes an extra token, and time is fused with
            # actions through a 2-layer MLP over the concatenated features.
            self.state_proj = nn.Linear(32, action_expert_config.width)
            self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
            self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)

        torch.set_float32_matmul_precision("high")
        # Compiles the *bound method*; the instance attribute shadows the class
        # method from here on.
        self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")

        # Initialize gradient checkpointing flag
        self.gradient_checkpointing_enabled = False

        # NOTE(review): this message hard-codes a user-specific site-packages
        # path — consider deriving the destination from the active environment.
        msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* /home/tianyang/miniconda3/envs/lam3d/lib/python3.11/site-packages/transformers/`."
        try:
            from transformers.models.siglip import check

            if not check.check_whether_transformers_replace_is_installed_correctly():
                raise ValueError(msg)
        except ImportError:
            # The sentinel module itself is missing — same remediation applies.
            raise ValueError(msg) from None

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory optimization."""
        self.gradient_checkpointing_enabled = True
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True

        logging.info("Enabled gradient checkpointing for PI0Pytorch model")

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self.gradient_checkpointing_enabled = False
        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False

        logging.info("Disabled gradient checkpointing for PI0Pytorch model")

    def is_gradient_checkpointing_enabled(self):
        """Check if gradient checkpointing is enabled."""
        return self.gradient_checkpointing_enabled

    def _apply_checkpoint(self, func, *args, **kwargs):
        """Helper method to apply gradient checkpointing if enabled.

        Checkpointing is only used in training mode; at eval time `func` is
        called directly so inference pays no recomputation cost.
        """
        if self.gradient_checkpointing_enabled and self.training:
            return torch.utils.checkpoint.checkpoint(
                func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
            )
        return func(*args, **kwargs)

    def _prepare_attention_masks_4d(self, att_2d_masks):
        """Helper method to prepare 4D attention masks for transformer.

        Converts bool[B, T, S] into additive float masks [B, 1, T, S]:
        0 where attention is allowed, a large negative number where blocked.
        """
        att_2d_masks_4d = att_2d_masks[:, None, :, :]
        return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)

    def _preprocess_observation(self, observation, *, train=True):
        """Helper method to preprocess observation.

        Returns (images, image_masks, tokenized_prompt, tokenized_prompt_mask,
        state) with images/masks unpacked from dicts into lists.
        """
        observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
        return (
            list(observation.images.values()),
            list(observation.image_masks.values()),
            observation.tokenized_prompt,
            observation.tokenized_prompt_mask,
            observation.state,
        )

    def sample_noise(self, shape, device):
        """Sample standard Gaussian noise of the given shape in float32."""
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )

    def sample_time(self, bsize, device):
        """Sample flow-matching times from Beta(1.5, 1), rescaled into (0, 1]."""
        time_beta = sample_beta(1.5, 1.0, bsize, device)
        # Keep times strictly positive (min 0.001) to avoid a degenerate step.
        time = time_beta * 0.999 + 0.001
        return time.to(dtype=torch.float32, device=device)

    def embed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed images with SigLIP and language tokens with embedding layer to prepare
        for PaliGemma transformer processing.

        Returns (embs[B, P, D], pad_masks[B, P], att_masks[B, P]); all prefix
        tokens use att_mask 0, i.e. full bidirectional attention within the prefix.
        """
        embs = []
        pad_masks = []
        att_masks = []

        # Process images
        for img, img_mask in zip(images, img_masks, strict=True):

            def image_embed_func(img):
                return self.paligemma_with_expert.embed_image(img)

            img_emb = self._apply_checkpoint(image_embed_func, img)

            bsize, num_img_embs = img_emb.shape[:2]

            embs.append(img_emb)
            # One per-image mask bit broadcast over all patch tokens.
            pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))

            # Create attention masks so that image tokens attend to each other
            att_masks += [0] * num_img_embs

        # Process language tokens
        def lang_embed_func(lang_tokens):
            lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
            lang_emb_dim = lang_emb.shape[-1]
            # Standard transformer embedding scaling by sqrt(d_model).
            return lang_emb * math.sqrt(lang_emb_dim)

        lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)

        embs.append(lang_emb)
        pad_masks.append(lang_masks)

        # full attention between image and language inputs
        num_lang_embs = lang_emb.shape[1]
        att_masks += [0] * num_lang_embs

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)

        # Get batch size from the first dimension of the concatenated tensors
        bsize = pad_masks.shape[0]
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks

    def embed_suffix(self, state, noisy_actions, timestep):
        """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.

        Returns (embs, pad_masks, att_masks, adarms_cond); adarms_cond is None
        for the non-pi05 variant (time is fused into the action tokens instead).
        """
        embs = []
        pad_masks = []
        att_masks = []

        if not self.pi05:
            if self.state_proj.weight.dtype == torch.float32:
                state = state.to(torch.float32)

            # Embed state
            def state_proj_func(state):
                return self.state_proj(state)

            state_emb = self._apply_checkpoint(state_proj_func, state)

            embs.append(state_emb[:, None, :])
            bsize = state_emb.shape[0]
            device = state_emb.device

            state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
            pad_masks.append(state_mask)

            # Set attention masks so that image and language inputs do not attend to state or actions
            att_masks += [1]

        # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = create_sinusoidal_pos_embedding(
            timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device
        )
        time_emb = time_emb.type(dtype=timestep.dtype)

        # Fuse timestep + action information using an MLP
        def action_proj_func(noisy_actions):
            return self.action_in_proj(noisy_actions)

        action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)

        if not self.pi05:
            # Concatenate per-token action features with the broadcast time
            # embedding, then fuse with a 2-layer swish MLP.
            time_emb = time_emb[:, None, :].expand_as(action_emb)
            action_time_emb = torch.cat([action_emb, time_emb], dim=2)

            # Apply MLP layers
            def mlp_func(action_time_emb):
                x = self.action_time_mlp_in(action_time_emb)
                x = F.silu(x)  # swish == silu
                return self.action_time_mlp_out(x)

            action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
            adarms_cond = None
        else:
            # time MLP (for adaRMS)
            def time_mlp_func(time_emb):
                x = self.time_mlp_in(time_emb)
                x = F.silu(x)  # swish == silu
                x = self.time_mlp_out(x)
                return F.silu(x)

            time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
            action_time_emb = action_emb
            adarms_cond = time_emb

        # Add to input tokens
        embs.append(action_time_emb)

        bsize, action_time_dim = action_time_emb.shape[:2]
        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
        pad_masks.append(action_time_mask)

        # Set attention masks so that image, language and state inputs do not attend to action tokens
        att_masks += [1] + ([0] * (self.config.action_horizon - 1))

        embs = torch.cat(embs, dim=1)
        pad_masks = torch.cat(pad_masks, dim=1)
        # NOTE(review): here att_masks takes embs.dtype (float), whereas
        # embed_prefix uses torch.bool — verify downstream consumers accept both.
        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
        att_masks = att_masks[None, :].expand(bsize, len(att_masks))

        return embs, pad_masks, att_masks, adarms_cond

    def forward(self, observation, actions, noise=None, time=None) -> Tensor:
        """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)

        Implements the flow-matching objective: interpolate x_t between noise
        and actions at a sampled time, predict the velocity v_t with the action
        expert, and return the unreduced MSE against the target u_t = noise - actions.
        """
        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)

        if noise is None:
            noise = self.sample_noise(actions.shape, actions.device)

        if time is None:
            time = self.sample_time(actions.shape[0], actions.device)

        # Linear interpolation path: x_t = t*noise + (1-t)*actions.
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
        # Match the backbone's parameter dtype when it runs in bfloat16.
        if (
            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
            == torch.bfloat16
        ):
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        # Prepare attention masks
        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)

        # Apply gradient checkpointing if enabled
        def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
            (_, suffix_out), _ = self.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            return suffix_out

        suffix_out = self._apply_checkpoint(
            forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
        )

        # Keep only the action tokens (last action_horizon positions).
        suffix_out = suffix_out[:, -self.config.action_horizon :]
        suffix_out = suffix_out.to(dtype=torch.float32)

        # Apply gradient checkpointing to final action projection if enabled
        def action_out_proj_func(suffix_out):
            return self.action_out_proj(suffix_out)

        v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)

        # Unreduced loss so callers can mask/weight per element.
        return F.mse_loss(u_t, v_t, reduction="none")

    @torch.no_grad()
    def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
        """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)

        Encodes the prefix once into a KV cache, then integrates the learned
        velocity field from t=1 (noise) to t=0 (actions) with `num_steps`
        Euler steps of size dt = -1/num_steps.
        """
        bsize = observation.state.shape[0]
        if noise is None:
            actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
            noise = self.sample_noise(actions_shape, device)

        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)

        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        # Compute image and language key value cache
        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001

        _, past_key_values = self.paligemma_with_expert.forward(
            attention_mask=prefix_att_2d_masks_4d,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=True,
        )

        dt = -1.0 / num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=device)

        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=device)
        # Loop until t reaches 0; the -dt/2 tolerance absorbs float error.
        while time >= -dt / 2:
            expanded_time = time.expand(bsize)
            v_t = self.denoise_step(
                state,
                prefix_pad_masks,
                past_key_values,
                x_t,
                expanded_time,
            )

            # Euler step - use new tensor assignment instead of in-place operation
            x_t = x_t + dt * v_t
            time += dt
        return x_t

    def denoise_step(
        self,
        state,
        prefix_pad_masks,
        past_key_values,
        x_t,
        timestep,
    ):
        """Apply one denoising step of the noise `x_t` at a given timestep.

        Runs only the suffix (state/action tokens) through the action expert,
        attending to the cached prefix keys/values; returns the predicted
        velocity for the action tokens.
        """
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)

        suffix_len = suffix_pad_masks.shape[1]
        batch_size = prefix_pad_masks.shape[0]
        prefix_len = prefix_pad_masks.shape[1]

        # Suffix queries may attend to every valid prefix token.
        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)

        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

        # Continue position ids after the (unpadded) prefix length.
        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1

        # Prepare attention masks
        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager"  # noqa: SLF001

        outputs_embeds, _ = self.paligemma_with_expert.forward(
            attention_mask=full_att_2d_masks_4d,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=[None, suffix_embs],
            use_cache=False,
            adarms_cond=[None, adarms_cond],
        )

        suffix_out = outputs_embeds[1]
        suffix_out = suffix_out[:, -self.config.action_horizon :]
        suffix_out = suffix_out.to(dtype=torch.float32)
        return self.action_out_proj(suffix_out)
|
||||
@@ -0,0 +1,173 @@
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from openpi.shared import image_tools
|
||||
|
||||
# Package-wide logger; handlers/levels are configured by the application.
logger = logging.getLogger("openpi")

# Constants moved from model.py
# Camera streams every observation is expected to provide, in canonical order.
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)

# Target (height, width) that all camera images are resized/padded to.
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
def preprocess_observation_pytorch(
    observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
    """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.

    This function avoids complex type annotations that can cause torch.compile issues.

    Resizes each camera image to `image_resolution` and, when `train` is True,
    applies random crop/rotation (non-wrist cameras only) plus brightness,
    contrast and saturation jitter. Images are assumed to arrive in [-1, 1];
    missing image masks default to all-ones.

    Args:
        observation: Object with `.images` (dict of [B,C,H,W] or [B,H,W,C]
            tensors), `.image_masks`, `.state`, and tokenized-prompt fields.
        train: Whether to apply data augmentation.
        image_keys: Camera keys that must be present in `observation.images`.
        image_resolution: Target (height, width).

    Returns:
        A lightweight object mirroring the Observation attributes, with
        processed `images` and `image_masks`.

    Raises:
        ValueError: If any of `image_keys` is missing from `observation.images`.
    """
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")

    # Leading (batch) dims of the state; used to size default image masks.
    batch_shape = observation.state.shape[:-1]

    out_images = {}
    for key in image_keys:
        image = observation.images[key]

        # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
        # Handle both [B, C, H, W] and [B, H, W, C] formats
        # NOTE(review): this heuristic misclassifies genuine [B, H, W, C]
        # images whose height is 3 — confirm inputs can never have H == 3.
        is_channels_first = image.shape[1] == 3  # Check if channels are in dimension 1

        if is_channels_first:
            # Convert [B, C, H, W] to [B, H, W, C] for processing
            image = image.permute(0, 2, 3, 1)

        if image.shape[1:3] != image_resolution:
            logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
            image = image_tools.resize_with_pad_torch(image, *image_resolution)

        if train:
            # Convert from [-1, 1] to [0, 1] for PyTorch augmentations
            image = image / 2.0 + 0.5

            # Apply PyTorch-based augmentations
            if "wrist" not in key:
                # Geometric augmentations for non-wrist cameras
                height, width = image.shape[1:3]

                # Random crop and resize
                crop_height = int(height * 0.95)
                crop_width = int(width * 0.95)

                # Random crop
                max_h = height - crop_height
                max_w = width - crop_width
                if max_h > 0 and max_w > 0:
                    # Use tensor operations instead of .item() for torch.compile compatibility
                    start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
                    start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
                    image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]

                    # Resize back to original size
                    image = torch.nn.functional.interpolate(
                        image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]
                        size=(height, width),
                        mode="bilinear",
                        align_corners=False,
                    ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

                # Random rotation (small angles)
                # Use tensor operations instead of .item() for torch.compile compatibility
                angle = torch.rand(1, device=image.device) * 10 - 5  # Random angle between -5 and 5 degrees
                if torch.abs(angle) > 0.1:  # Only rotate if angle is significant
                    # Convert to radians
                    angle_rad = angle * torch.pi / 180.0

                    # Create rotation matrix
                    cos_a = torch.cos(angle_rad)
                    sin_a = torch.sin(angle_rad)

                    # Apply rotation using grid_sample
                    grid_x = torch.linspace(-1, 1, width, device=image.device)
                    grid_y = torch.linspace(-1, 1, height, device=image.device)

                    # Create meshgrid
                    grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")

                    # Expand to batch dimension
                    grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
                    grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)

                    # Apply rotation transformation
                    grid_x_rot = grid_x * cos_a - grid_y * sin_a
                    grid_y_rot = grid_x * sin_a + grid_y * cos_a

                    # Stack and reshape for grid_sample
                    grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)

                    image = torch.nn.functional.grid_sample(
                        image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]
                        grid,
                        mode="bilinear",
                        padding_mode="zeros",
                        align_corners=False,
                    ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

            # Color augmentations for all cameras
            # Random brightness
            # Use tensor operations instead of .item() for torch.compile compatibility
            brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6  # Random factor between 0.7 and 1.3
            image = image * brightness_factor

            # Random contrast
            # Use tensor operations instead of .item() for torch.compile compatibility
            contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8  # Random factor between 0.6 and 1.4
            mean = image.mean(dim=[1, 2, 3], keepdim=True)
            image = (image - mean) * contrast_factor + mean

            # Random saturation (convert to HSV, modify S, convert back)
            # For simplicity, we'll just apply a random scaling to the color channels
            # Use tensor operations instead of .item() for torch.compile compatibility
            saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0  # Random factor between 0.5 and 1.5
            gray = image.mean(dim=-1, keepdim=True)
            image = gray + (image - gray) * saturation_factor

            # Clamp values to [0, 1]
            image = torch.clamp(image, 0, 1)

            # Back to [-1, 1]
            image = image * 2.0 - 1.0

        # Convert back to [B, C, H, W] format if it was originally channels-first
        if is_channels_first:
            image = image.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        out_images[key] = image

    # obtain mask
    out_masks = {}
    for key in out_images:
        if key not in observation.image_masks:
            # do not mask by default
            out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
        else:
            out_masks[key] = observation.image_masks[key]

    # Create a simple object with the required attributes instead of using the complex Observation class
    class SimpleProcessedObservation:
        # Attribute-bag stand-in for Observation; stores every kwarg verbatim.
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    return SimpleProcessedObservation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
|
||||
@@ -0,0 +1,173 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma-7B.
|
||||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GemmaModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The legacy activation function. It is overwritten by the `hidden_activation`.
|
||||
hidden_activation (`str` or `function`, *optional*):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
use_adarms (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use ADARMS.
|
||||
adarms_cond_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the ADARMS condition.
|
||||
```python
|
||||
>>> from transformers import GemmaModel, GemmaConfig
|
||||
>>> # Initializing a Gemma gemma-7b style configuration
|
||||
>>> configuration = GemmaConfig()
|
||||
>>> # Initializing a model from the gemma-7b style configuration
|
||||
>>> model = GemmaModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
    self,
    vocab_size=256000,
    hidden_size=3072,
    intermediate_size=24576,
    num_hidden_layers=28,
    num_attention_heads=16,
    num_key_value_heads=16,
    head_dim=256,
    hidden_act="gelu_pytorch_tanh",
    hidden_activation=None,
    max_position_embeddings=8192,
    initializer_range=0.02,
    rms_norm_eps=1e-6,
    use_cache=True,
    pad_token_id=0,
    eos_token_id=1,
    bos_token_id=2,
    tie_word_embeddings=True,
    rope_theta=10000.0,
    attention_bias=False,
    attention_dropout=0.0,
    use_adarms: bool = False,
    adarms_cond_dim: Optional[int] = None,
    **kwargs,
):
    """Build a Gemma configuration; see the class docstring for per-argument details.

    All model-shape arguments are stored as attributes before calling the base
    class so that `super().__init__` (which may consult them via kwargs hooks)
    sees a fully populated config.
    """
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.head_dim = head_dim
    self.num_key_value_heads = num_key_value_heads
    self.hidden_act = hidden_act
    self.hidden_activation = hidden_activation
    self.initializer_range = initializer_range
    self.rms_norm_eps = rms_norm_eps
    self.use_cache = use_cache
    self.rope_theta = rope_theta
    self.attention_bias = attention_bias
    self.attention_dropout = attention_dropout
    # AdaRMS (adaptive RMSNorm) conditioning options consumed by the modeling code.
    self.use_adarms = use_adarms
    self.adarms_cond_dim = adarms_cond_dim

    # Default the AdaRMS condition width to the model width when enabled but unset.
    if self.use_adarms and self.adarms_cond_dim is None:
        self.adarms_cond_dim = self.hidden_size

    # Token ids and embedding tying are handled by the PretrainedConfig base class.
    super().__init__(
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        tie_word_embeddings=tie_word_embeddings,
        **kwargs,
    )
|
||||
|
||||
|
||||
__all__ = ["GemmaConfig"]
|
||||
@@ -0,0 +1,862 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
    """RMSNorm with optional adaptive (AdaRMS) modulation.

    Plain mode (``cond_dim is None``): normalizes over the last dimension and
    scales by ``1 + weight`` (Gemma convention; ``weight`` is zero-initialized,
    so the layer starts as a pure normalization).

    Adaptive mode (``cond_dim`` given): a zero-initialized linear layer maps a
    conditioning vector to per-channel ``(scale, shift, gate)``. ``scale`` and
    ``shift`` modulate the normalized activations; ``gate`` is returned to the
    caller for use in gated residual connections.

    Bug fixes vs. the original:
      * ``extra_repr`` no longer crashes with ``AttributeError`` in adaptive
        mode (where no ``weight`` parameter exists).
      * ``forward`` called without ``cond`` on an adaptive instance no longer
        crashes; it falls back to the plain normalized output, which matches
        the zero-init identity behavior.
    """

    def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.cond_dim = cond_dim

        if cond_dim is not None:
            # Adaptive path: project the condition to (scale, shift, gate).
            # Zero-init the weight so the modulation starts as an identity
            # (matches the source implementation).
            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
            nn.init.zeros_(self.dense.weight)
        else:
            # Plain RMSNorm: effective scale is (1 + weight), so zero init
            # means identity scale at the start of training.
            self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
            self.dense = None

    def _norm(self, x):
        # Variance and normalization computed in float32 for numerical
        # stability (matches the source implementation).
        var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps)

    def forward(self, x, cond=None):
        """Normalize ``x``; returns ``(normalized, gate)`` where ``gate`` is None in plain mode."""
        dtype = x.dtype  # original (possibly half-precision) dtype
        normed_inputs = self._norm(x)

        if cond is None or self.dense is None:
            if self.dense is None:
                # Regular RMSNorm: scale by the learned parameter in float32.
                normed_inputs = normed_inputs * (1.0 + self.weight.float())
            # else: adaptive layer invoked without a condition — there is no
            # ``weight`` parameter in that configuration, so return the plain
            # normalized output (the learned scale would be identity at init).
            return normed_inputs.to(dtype), None  # None gate: nothing to modulate residuals with

        # Adaptive RMSNorm.
        if cond.shape[-1] != self.cond_dim:
            raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")

        modulation = self.dense(cond)
        # Broadcast over the sequence dimension: [batch, 1, features] against
        # [batch, seq, features] inputs.
        if len(x.shape) == 3:
            modulation = modulation.unsqueeze(1)

        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)

        # Apply the modulation in float32, then cast back to the input dtype.
        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)

        return normed_inputs.to(dtype), gate.to(dtype)

    def extra_repr(self):
        # Report the normalized dim directly; adaptive instances have no
        # ``weight`` attribute (the original crashed here in that case).
        repr_str = f"({self.dim},), eps={self.eps}"
        if self.dense is not None:
            repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
        return repr_str
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
    """Gemma gated feed-forward block: ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Three bias-free projections (Gemma convention).
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Gating branch goes through the activation; the "up" branch is linear.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
|
||||
|
||||
|
||||
class GemmaRotaryEmbedding(nn.Module):
    """Computes rotary position embedding (RoPE) cos/sin tables for Gemma."""

    def __init__(self, config: GemmaConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Frequency initializer chosen by rope type (default / dynamic / ...).
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: inv_freq is recomputable from the config, so it is
        # deliberately excluded from the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-rope variants can restore the initial frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # [batch, dim/2, 1] @ [batch, 1, seq] -> per-position rotation angles.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled so the trig below runs in float32 even under
        # mixed precision ("mps" is remapped because it lacks this autocast mode).
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # Cast back to the activations' dtype for the attention matmuls.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they
            broadcast against `q`/`k`: use 1 for a
            [batch_size, heads, seq_len, head_dim] layout and 2 for a
            [batch_size, seq_len, heads, head_dim] layout
            (cos/sin themselves are [batch_size, seq_len, head_dim]).

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated
        using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat; return the input untouched.
        return hidden_states
    # expand() is a zero-copy view; only the final reshape materializes memory.
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, dim)
|
||||
|
||||
|
||||
def _gated_residual(x, y, gate):
|
||||
"""
|
||||
Applies gated residual connection with optional gate parameter.
|
||||
|
||||
Args:
|
||||
x: Input tensor (residual)
|
||||
y: Output tensor to be added
|
||||
gate: Optional gate tensor to modulate the addition
|
||||
|
||||
Returns:
|
||||
x + y if gate is None, otherwise x + y * gate
|
||||
"""
|
||||
if x is None and y is None:
|
||||
return None
|
||||
if x is None or y is None:
|
||||
return x if x is not None else y
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
|
||||
|
||||
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention with GQA support.

    Returns ``(attn_output, attn_weights)`` where the output is laid out as
    (batch, seq, heads, head_dim) and the weights are post-dropout.
    """
    # Expand k/v heads so they line up one-to-one with the query heads (GQA).
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the (possibly static-cache-sized) mask to the actual key length.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for numerical stability, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    output = torch.matmul(probs, values)
    # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim).
    output = output.transpose(1, 2).contiguous()

    return output, probs
|
||||
|
||||
|
||||
class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # layer_idx identifies this layer's slot in the KV cache.
        self.layer_idx = layer_idx
        # Fall back to hidden_size // heads when the config has no explicit head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # GQA: number of query heads sharing one key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Project, apply RoPE, update/read the KV cache, attend, and re-project."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq, d_model) -> (batch, heads, seq, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Use cache if provided
        if past_key_value is not None:
            if use_cache:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            else:
                # Read-only path: concatenate stored k/v without mutating the cache.
                key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
                value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)

        # Dispatch to the configured attention backend, defaulting to the
        # reference eager implementation.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back to (batch, seq, heads * head_dim) and re-project.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
||||
|
||||
|
||||
class GemmaDecoderLayer(GradientCheckpointingLayer):
    """One Gemma transformer block: (Ada)RMSNorm -> self-attention -> gated
    residual, then (Ada)RMSNorm -> MLP -> gated residual."""

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)

        self.mlp = GemmaMLP(config)
        # With AdaRMS enabled, both norms accept a conditioning vector and
        # additionally return a gate used for the residual connections.
        cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        adarms_cond: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block; returns (hidden_states,) plus attention weights when requested."""
        residual = hidden_states
        # gate is None unless the norm is adaptive and a condition was given.
        hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = _gated_residual(residual, hidden_states, gate)

        # Fully Connected
        residual = hidden_states
        hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
        hidden_states = self.mlp(hidden_states)
        hidden_states = _gated_residual(residual, hidden_states, gate)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
|
||||
|
||||
|
||||
@auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
    """Base class wiring config type, weight init, and capability flags
    consumed by the transformers runtime."""

    config_class = GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GemmaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Per-submodule weight initialization invoked by ``post_init``."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GemmaRMSNorm):
            # Adaptive RMSNorm instances have no ``weight`` attribute, hence
            # the hasattr guard.
            # NOTE(review): fill_(1.0) disagrees with the zeros init in
            # GemmaRMSNorm.__init__ — and the norm applies (1 + weight), so
            # filling with 1.0 yields a scale of 2. Confirm which init is
            # intended.
            if hasattr(module, 'weight'):
                module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
class GemmaModel(GemmaPreTrainedModel):
    """Bare Gemma decoder stack: token embeddings, decoder layers, final norm."""

    def __init__(self, config: GemmaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        # AdaRMS conditioning is only wired in when enabled in the config.
        cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
        self.rotary_emb = GemmaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        adarms_cond: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """
        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        """
        # Resolve per-call flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # Absolute positions of the current tokens, offset by the cache length.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # embed positions
        hidden_states = inputs_embeds
        # Convert to bfloat16 if the first layer uses bfloat16
        if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.bfloat16)

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # normalized
        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        # NOTE(review): ``normalizer`` is computed but the multiplication is
        # commented out below, so embedding scaling is currently disabled —
        # confirm this is intentional for this Pi/openpi variant.
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        #hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                adarms_cond=adarms_cond,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final (Ada)RMSNorm; the gate is irrelevant here since there is no
        # residual connection after the last norm.
        hidden_states, _ = self.norm(hidden_states, adarms_cond)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
||||
|
||||
|
||||
# Typed-kwargs bundle combining flash-attention kwargs with loss kwargs,
# forwarded through GemmaForCausalLM.forward.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
    """Gemma decoder stack with a language-modeling head for causal generation."""

    # lm_head may share weights with the input embeddings (tie_word_embeddings).
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        adarms_cond: Optional[torch.Tensor] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            adarms_cond=adarms_cond,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
|
||||
@auto_docstring(
    custom_intro="""
    The Gemma Model transformer with a sequence classification head on top (linear layer).

    [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class GemmaForSequenceClassification(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = GemmaModel(config)
        # Linear classification head applied to the pooled hidden state.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        adarms_cond: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        """
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            adarms_cond=adarms_cond,
        )
        # Per-token logits; a single position per sequence is pooled below.
        logits = self.score(transformer_outputs.last_hidden_state)

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if self.config.pad_token_id is None:
            # No pad token: the only defensible choice is the final position.
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
||||
|
||||
|
||||
@auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = GemmaModel(config)
        # Dropout probability: prefer `classifier_dropout`, then `hidden_dropout`, else fall back to 0.1.
        drop_p = getattr(config, "classifier_dropout", None)
        if drop_p is None:
            drop_p = getattr(config, "hidden_dropout", None)
        if drop_p is None:
            drop_p = 0.1
        self.dropout = nn.Dropout(drop_p)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        adarms_cond: Optional[torch.Tensor] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            adarms_cond=adarms_cond,
        )
        # Per-token classification: dropout, then linear head over the last hidden states.
        logits = self.score(self.dropout(outputs.last_hidden_state))

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
|
||||
# Public API of this module, in the order the classes are defined above.
__all__ = [
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",
]
|
||||
@@ -0,0 +1,622 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch PaliGemma model."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
||||
from ..auto import AutoModel
|
||||
from .configuration_paligemma import PaliGemmaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Paligemma outputs, with hidden states and attentions.
    """
)
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Extends BaseModelOutputWithPast with the projected vision-encoder features.
    image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for PaliGemma causal language model (or autoregressive) outputs.
    """
)
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    """

    # All fields default to None so partially populated outputs (e.g. no loss at inference) remain valid.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProjector(nn.Module):
    # Projects vision-tower features into the language model's embedding space
    # via a single affine map (no activation).
    def __init__(self, config: PaliGemmaConfig):
        super().__init__()
        self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)

    def forward(self, image_features):
        return self.linear(image_features)
|
||||
|
||||
|
||||
@auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
    config_class = PaliGemmaConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["PaliGemmaMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # important: this ported version of PaliGemma isn't meant for training from scratch - only
        # inference and fine-tuning
        std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
    custom_intro="""
    The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.
    """
)
class PaliGemmaModel(PaliGemmaPreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: PaliGemmaConfig):
        super().__init__(config)
        # Vision encoder plus the projector mapping its features into the LM embedding space.
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        # Decoder-only language model; the LM head lives in the ForConditionalGeneration wrapper.
        self.language_model = AutoModel.from_config(config=config.text_config)

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def _update_causal_mask(
        self,
        attention_mask,
        token_type_ids=None,
        past_key_values=None,
        cache_position=None,
        input_tensor=None,
        is_training: Optional[bool] = None,
    ):
        """Build the 4D additive attention mask for PaliGemma's prefix-LM attention.

        During training the prefix (token_type_ids == 0) attends bidirectionally while
        suffix tokens are causal; at inference the whole prefix is attended to.
        Returns a `(batch, 1, query_len, kv_len)` float mask of 0 / dtype-min values,
        a 2D mask passed through unchanged (flash-attention path), or None.
        """
        if self.config.text_config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the raw 2D padding mask (or None when nothing is masked).
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        is_training = is_training if is_training is not None else self.training
        min_dtype = torch.finfo(self.dtype).min
        if input_tensor is None:
            input_tensor = attention_mask

        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
        # Static and hybrid caches are pre-allocated: mask out to the full cache length.
        if isinstance(past_key_values, (StaticCache, HybridCache)):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            return attention_mask

        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
        if sequence_length != 1:
            if is_training:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                causal_mask[:, :sequence_length] = 0.0

        # Positions beyond the current cache positions are always masked.
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]

            # First unmask prefix tokens during training
            if is_training:
                if token_type_ids is None:
                    raise ValueError("Token type ids must be provided during training")
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )

            # Then apply padding mask (will mask pad tokens)
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, PaligemmaModelOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Training mode (prefix-LM attention) is inferred from the presence of both signals.
        is_training = token_type_ids is not None and labels is not None

        # Replace the image token id with 0 if it is OOV (out of the embedding table's range), to avoid index errors.
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)

            if input_ids is None:
                # Without ids, locate image positions by comparing embeddings to the image-token embedding.
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
        )
        outputs = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return PaligemmaModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )
|
||||
|
||||
|
||||
# Combined keyword arguments accepted by the causal-LM forward pass
# (flash-attention kwargs merged with loss kwargs).
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    ...
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
    super().__init__(config)
    # Multimodal backbone plus the (weight-tied) language modeling head.
    self.model = PaliGemmaModel(config)
    self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
    self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
    """Delegate to the wrapped backbone's token-embedding accessor."""
    return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
    """Delegate to the wrapped backbone's token-embedding setter."""
    self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
    """The LM head is the output-embedding layer."""
    return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
    """Replace the LM head (output-embedding layer)."""
    self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
    """Forward to the backbone's decoder setter."""
    self.model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
    """Forward to the backbone's decoder accessor."""
    return self.model.get_decoder()
|
||||
|
||||
def get_image_features(self, pixel_values):
    """Forward to the backbone's vision-feature extraction."""
    return self.model.get_image_features(pixel_values)
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
# Make modules available through the conditional class for backward compatibility.
@property
def language_model(self):
    return self.model.language_model
|
||||
|
||||
# Backward-compatible alias onto the wrapped backbone.
@property
def vision_tower(self):
    return self.model.vision_tower
|
||||
|
||||
# Backward-compatible alias onto the wrapped backbone.
@property
def multi_modal_projector(self):
    return self.model.multi_modal_projector
|
||||
|
||||
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

    >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
    >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

    >>> prompt = "Where is the cat standing?"
    >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs,)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Where is the cat standing?\nsnow"
    ```"""
    # Resolve output-control flags from config defaults when not given explicitly.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    backbone_out = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        labels=labels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        **kwargs,
    )

    last_hidden = backbone_out[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = self.lm_head(last_hidden[:, keep, :])

    loss = None
    if labels is not None:
        loss = self.loss_function(
            logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
        )

    return PaliGemmaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=backbone_out.past_key_values,
        hidden_states=backbone_out.hidden_states,
        attentions=backbone_out.attentions,
        image_hidden_states=backbone_out.image_hidden_states,
    )
|
||||
|
||||
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    pixel_values=None,
    attention_mask=None,
    token_type_ids=None,
    use_cache=True,
    logits_to_keep=None,
    labels=None,
    **kwargs,
):
    # Overwritten -- custom `position_ids` and `pixel_values` handling
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=use_cache,
        logits_to_keep=logits_to_keep,
        token_type_ids=token_type_ids,
        **kwargs,
    )

    # position_ids in Paligemma are 1-indexed
    if model_inputs.get("position_ids") is not None:
        model_inputs["position_ids"] += 1

    # If we're in cached decoding stage, pixel values should be None because input ids do not contain
    # special image tokens anymore. Otherwise pixel values must be passed to the model.
    # NOTE: use_cache=False needs pixel_values always.
    is_prefill = cache_position[0] == 0
    if is_prefill:
        model_inputs["pixel_values"] = pixel_values

    if is_prefill and isinstance(past_key_values, HybridCache):
        # Hybrid caches need the full 4D mask built up-front for the prefill step.
        training_mode = token_type_ids is not None and labels is not None
        mask_source = inputs_embeds if inputs_embeds is not None else input_ids
        model_inputs["attention_mask"] = self.model._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, mask_source, training_mode
        )

    return model_inputs
|
||||
|
||||
    @staticmethod
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            # "Inverted" mask convention: 0 where attention is allowed, dtype-min where blocked.
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                # Block the strictly-upper-triangular part: each token attends to itself and the past.
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Block key positions beyond each query token's absolute cache position.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            # Broadcast to (batch, 1, query, key).
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # A position that is allowed causally (0) but masked as padding (0) sums to 0;
                # those positions get filled with dtype-min below.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
|
||||
|
||||
|
||||
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]
|
||||
@@ -0,0 +1,4 @@
|
||||
import transformers
|
||||
|
||||
def check_whether_transformers_replace_is_installed_correctly():
    """Return True iff the exact transformers version this repo pins (4.53.2) is installed."""
    expected_version = "4.53.2"
    return transformers.__version__ == expected_version
|
||||
File diff suppressed because it is too large
Load Diff
202
policy/openpi-InternData-A1/src/openpi/policies/aloha_policy.py
Normal file
202
policy/openpi-InternData-A1/src/openpi/policies/aloha_policy.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import dataclasses
|
||||
from typing import ClassVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
|
||||
|
||||
def make_aloha_example() -> dict:
    """Build a random observation dict matching the Aloha policy input schema."""

    def _random_camera() -> np.ndarray:
        # Channel-first uint8 images, matching what the Aloha runtime produces.
        return np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)

    camera_names = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
    return {
        "state": np.ones((14,)),
        "images": {name: _random_camera() for name in camera_names},
        "prompt": "do something",
    }
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class AlohaInputs(transforms.DataTransformFn):
    """Inputs for the Aloha policy.

    Expected inputs:
      - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
      - state: [14]
      - actions: [action_horizon, 14]
    """

    # If true, convert joint/gripper values from the standard Aloha space to the
    # space used by the pi internal runtime that the base model was trained in.
    adapt_to_pi: bool = True

    # All input cameras must come from this set. Missing cameras are replaced with
    # black images and their `image_mask` entry is set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        in_images = data["images"]
        if set(in_images) - set(self.EXPECTED_CAMERAS):
            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

        # The base camera is assumed to always be present.
        base_image = in_images["cam_high"]
        images = {"base_0_rgb": base_image}
        image_masks = {"base_0_rgb": np.True_}

        # Wrist cameras are optional; pad with zeros (and a False mask) when absent.
        wrist_mapping = (
            ("left_wrist_0_rgb", "cam_left_wrist"),
            ("right_wrist_0_rgb", "cam_right_wrist"),
        )
        for dest, source in wrist_mapping:
            present = source in in_images
            images[dest] = in_images[source] if present else np.zeros_like(base_image)
            image_masks[dest] = np.True_ if present else np.False_

        inputs = {
            "image": images,
            "image_mask": image_masks,
            "state": data["state"],
        }

        # Actions are only present during training.
        if "actions" in data:
            inputs["actions"] = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class AlohaOutputs(transforms.DataTransformFn):
    """Outputs for the Aloha policy."""

    # If true, convert joint/gripper values back from the pi internal runtime
    # space to the standard Aloha space.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Aloha consumes 14 action dims; drop any model padding beyond that.
        trimmed = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(trimmed, adapt_to_pi=self.adapt_to_pi)}
|
||||
|
||||
|
||||
def _joint_flip_mask() -> np.ndarray:
|
||||
"""Used to convert between aloha and pi joint angles."""
|
||||
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
def _normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def _unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def _gripper_to_angular(value):
    """Map an Aloha (normalized linear) gripper position into pi0's normalized angular space."""
    # Aloha transforms gripper positions into a linear space; undo that first.
    # Constants from the Aloha code: PUPPET_GRIPPER_POSITION_OPEN / PUPPET_GRIPPER_POSITION_CLOSED.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    def linear_to_radian(linear_position, arm_length, horn_radius):
        # Inverse of the angular-to-linear transformation inside the Interbotix code.
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # Geometry constants taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110); with
    # 4096 total counts and a zero at 2048 this corresponds to (0.5476, 1.6296) radians.
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
|
||||
|
||||
|
||||
def _gripper_from_angular(value):
    """Convert a pi0 gripper prediction back to the Aloha gripper-joint space.

    The units remain angular; only the zero offset and normalization range change.
    """
    # Model predictions are already in radians; shift by the pi0 zero offset
    # (see _gripper_to_angular for the derivation of 0.5476).
    shifted = value + 0.5476

    # Aloha constants: PUPPET_GRIPPER_JOINT_OPEN / PUPPET_GRIPPER_JOINT_CLOSE.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def _gripper_from_angular_inv(value):
    """Exact inverse of _gripper_from_angular."""
    unscaled = _unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return unscaled - 0.5476
|
||||
|
||||
|
||||
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Normalize raw Aloha observations: decode state and convert images to HWC uint8.

    State layout: [left_arm_joint_angles(6), left_gripper(1), right_arm_joint_angles(6), right_gripper(1)].
    """
    state = _decode_state(np.asarray(data["state"]), adapt_to_pi=adapt_to_pi)

    def _to_hwc_uint8(img):
        arr = np.asarray(img)
        if np.issubdtype(arr.dtype, np.floating):
            # Float images are scaled to uint8 — presumably in [0, 1]; verify against caller.
            arr = (255 * arr).astype(np.uint8)
        # [channel, height, width] -> [height, width, channel].
        return einops.rearrange(arr, "c h w -> h w c")

    data["images"] = {name: _to_hwc_uint8(img) for name, img in data["images"].items()}
    data["state"] = state
    return data
|
||||
|
||||
|
||||
def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
if adapt_to_pi:
|
||||
# Flip the joints.
|
||||
state = _joint_flip_mask() * state
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
state[[6, 13]] = _gripper_to_angular(state[[6, 13]])
|
||||
return state
|
||||
|
||||
|
||||
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
if adapt_to_pi:
|
||||
# Flip the joints.
|
||||
actions = _joint_flip_mask() * actions
|
||||
actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]])
|
||||
return actions
|
||||
|
||||
|
||||
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
if adapt_to_pi:
|
||||
actions = _joint_flip_mask() * actions
|
||||
actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]])
|
||||
return actions
|
||||
@@ -0,0 +1,81 @@
|
||||
import dataclasses
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from openpi.models import model as _model
|
||||
|
||||
|
||||
def make_droid_example() -> dict:
    """Build a random observation dict matching the Droid policy input schema."""

    def _random_image() -> np.ndarray:
        # Droid images are height-width-channel uint8.
        return np.random.randint(256, size=(224, 224, 3), dtype=np.uint8)

    return {
        "observation/exterior_image_1_left": _random_image(),
        "observation/wrist_image_left": _random_image(),
        "observation/joint_position": np.random.rand(7),
        "observation/gripper_position": np.random.rand(1),
        "prompt": "do something",
    }
|
||||
|
||||
|
||||
def _parse_image(image) -> np.ndarray:
|
||||
image = np.asarray(image)
|
||||
if np.issubdtype(image.dtype, np.floating):
|
||||
image = (255 * image).astype(np.uint8)
|
||||
if image.shape[0] == 3:
|
||||
image = einops.rearrange(image, "c h w -> h w c")
|
||||
return image
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class DroidInputs(transforms.DataTransformFn):
    """Maps raw Droid observations into the model's input schema."""

    # Determines which model will be used.
    model_type: _model.ModelType

    def __call__(self, data: dict) -> dict:
        gripper_pos = np.asarray(data["observation/gripper_position"])
        if gripper_pos.ndim == 0:
            # Promote a scalar gripper reading to 1-D so it concatenates with the joints.
            gripper_pos = gripper_pos[np.newaxis]
        state = np.concatenate([data["observation/joint_position"], gripper_pos])

        # LeRobot stores images as float32 (C,H,W); parse to uint8 (H,W,C).
        # Policy-inference inputs are already in that form and pass through unchanged.
        base_image = _parse_image(data["observation/exterior_image_1_left"])
        wrist_image = _parse_image(data["observation/wrist_image_left"])

        if self.model_type in (_model.ModelType.PI0, _model.ModelType.PI05):
            names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
            images = (base_image, wrist_image, np.zeros_like(base_image))
            image_masks = (np.True_, np.True_, np.False_)
        elif self.model_type == _model.ModelType.PI0_FAST:
            names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb")
            # Padding images are not masked out for FAST models.
            images = (base_image, np.zeros_like(base_image), wrist_image)
            image_masks = (np.True_, np.True_, np.True_)
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        inputs = {
            "state": state,
            "image": dict(zip(names, images, strict=True)),
            "image_mask": dict(zip(names, image_masks, strict=True)),
        }

        if "actions" in data:
            inputs["actions"] = np.asarray(data["actions"])

        if "prompt" in data:
            prompt = data["prompt"]
            if isinstance(prompt, bytes):
                prompt = prompt.decode("utf-8")
                data["prompt"] = prompt
            inputs["prompt"] = prompt

        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class DroidOutputs(transforms.DataTransformFn):
    """Maps model outputs back to the Droid action space."""

    def __call__(self, data: dict) -> dict:
        # The Droid action space is 8-dimensional; drop any padding dims beyond that.
        trimmed = data["actions"][:, :8]
        return {"actions": np.asarray(trimmed)}
|
||||
100
policy/openpi-InternData-A1/src/openpi/policies/libero_policy.py
Normal file
100
policy/openpi-InternData-A1/src/openpi/policies/libero_policy.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import dataclasses
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from openpi.models import model as _model
|
||||
|
||||
|
||||
def make_libero_example() -> dict:
    """Build a random observation dict matching the Libero policy input schema."""
    example = {"observation/state": np.random.rand(8)}
    for image_key in ("observation/image", "observation/wrist_image"):
        # Libero images are height-width-channel uint8.
        example[image_key] = np.random.randint(256, size=(224, 224, 3), dtype=np.uint8)
    example["prompt"] = "do something"
    return example
|
||||
|
||||
|
||||
def _parse_image(image) -> np.ndarray:
|
||||
image = np.asarray(image)
|
||||
if np.issubdtype(image.dtype, np.floating):
|
||||
image = (255 * image).astype(np.uint8)
|
||||
if image.shape[0] == 3:
|
||||
image = einops.rearrange(image, "c h w -> h w c")
|
||||
return image
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class LiberoInputs(transforms.DataTransformFn):
    """
    Converts dataset items into the model input format, for both training and inference.

    To adapt this to your own dataset, copy this class and change the dataset keys
    referenced below per the inline comments.
    """

    # Determines which model will be used. Do not change this for your own dataset.
    model_type: _model.ModelType

    def __call__(self, data: dict) -> dict:
        # LeRobot stores images as float32 (C,H,W); parse to uint8 (H,W,C). Inference
        # inputs are already in that form and pass through unchanged. Keep this for
        # your own dataset, but change the keys below if your images live elsewhere.
        base_image = _parse_image(data["observation/image"])
        wrist_image = _parse_image(data["observation/wrist_image"])

        # Pi0 models take three images: one third-person view plus left/right wrist
        # views. Missing views are padded with zero arrays. The padding view is masked
        # out for pi0 but not for pi0-FAST; do not change that for your own dataset.
        pad_image_mask = np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_

        # Do not change the keys in the dict below.
        inputs = {
            "state": data["observation/state"],
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                "right_wrist_0_rgb": pad_image_mask,
            },
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Language instruction. Modify the source key if your instruction is stored
        # elsewhere; the output dict always needs the key "prompt".
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class LiberoOutputs(transforms.DataTransformFn):
    """
    Converts model outputs back to the dataset-specific format. Inference only.

    To adapt this to your own dataset, replace the `7` below with your dataset's
    action dimension.
    """

    def __call__(self, data: dict) -> dict:
        # Actions were padded to the model action dim on the way in; Libero only uses
        # the first 7 dims, so parse those back out here.
        trimmed = data["actions"][:, :7]
        return {"actions": np.asarray(trimmed)}
|
||||
135
policy/openpi-InternData-A1/src/openpi/policies/policy.py
Normal file
135
policy/openpi-InternData-A1/src/openpi/policies/policy.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import flax
|
||||
import flax.traverse_util
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from openpi_client import base_policy as _base_policy
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi import transforms as _transforms
|
||||
from openpi.models import model as _model
|
||||
from openpi.shared import array_typing as at
|
||||
from openpi.shared import nnx_utils
|
||||
|
||||
BasePolicy: TypeAlias = _base_policy.BasePolicy
|
||||
|
||||
|
||||
class Policy(BasePolicy):
    """Runs a trained model (JAX or PyTorch) behind the openpi-client policy interface,
    applying input transforms before sampling and output transforms after."""

    def __init__(
        self,
        model: _model.BaseModel,
        *,
        rng: at.KeyArrayLike | None = None,
        transforms: Sequence[_transforms.DataTransformFn] = (),
        output_transforms: Sequence[_transforms.DataTransformFn] = (),
        sample_kwargs: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        pytorch_device: str = "cpu",
        is_pytorch: bool = False,
    ):
        """Initialize the Policy.

        Args:
            model: The model to use for action sampling.
            rng: Random number generator key for JAX models. Ignored for PyTorch models.
            transforms: Input data transformations to apply before inference.
            output_transforms: Output data transformations to apply after inference.
            sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
            metadata: Additional metadata to store with the policy.
            pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
                Only relevant when is_pytorch=True.
            is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
        """
        self._model = model
        self._input_transform = _transforms.compose(transforms)
        self._output_transform = _transforms.compose(output_transforms)
        self._sample_kwargs = sample_kwargs or {}
        self._metadata = metadata or {}
        self._is_pytorch_model = is_pytorch
        self._pytorch_device = pytorch_device

        if self._is_pytorch_model:
            # Move to the target device and switch to inference mode.
            self._model = self._model.to(pytorch_device)
            self._model.eval()
            self._sample_actions = model.sample_actions
        else:
            # JAX model setup: JIT-compile sampling and keep an RNG key across calls.
            self._sample_actions = nnx_utils.module_jit(model.sample_actions)
            self._rng = rng or jax.random.key(0)

    @override
    def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict:  # type: ignore[misc]
        """Run one inference step on a single (unbatched) observation.

        Args:
            obs: Raw observation dict; transformed via the input transforms.
            noise: Optional sampling noise, shape (action_horizon, action_dim) or
                already batched (1, action_horizon, action_dim).

        Returns:
            Dict with "state", "actions" (unbatched numpy arrays after output
            transforms) and "policy_timing" with the model call time in ms.
        """
        # Make a copy since transformations may modify the inputs in place.
        inputs = jax.tree.map(lambda x: x, obs)
        inputs = self._input_transform(inputs)
        if not self._is_pytorch_model:
            # Make a batch and convert to jax.Array.
            inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)
            # Split so each call uses a fresh sampling key.
            self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng)
        else:
            # Convert inputs to PyTorch tensors, add a batch dim, and move to the device.
            inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs)
            # PyTorch sample_actions takes the device where JAX takes an RNG key.
            sample_rng_or_pytorch_device = self._pytorch_device

        # Prepare kwargs for sample_actions (copy so per-call noise doesn't leak).
        sample_kwargs = dict(self._sample_kwargs)
        if noise is not None:
            noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)

            if noise.ndim == 2:  # If noise is (action_horizon, action_dim), add batch dimension
                noise = noise[None, ...]  # Make it (1, action_horizon, action_dim)
            sample_kwargs["noise"] = noise

        observation = _model.Observation.from_dict(inputs)
        start_time = time.monotonic()
        outputs = {
            "state": inputs["state"],
            "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs),
        }
        model_time = time.monotonic() - start_time
        # Strip the batch dimension and bring results back to host numpy arrays.
        if self._is_pytorch_model:
            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...].detach().cpu()), outputs)
        else:
            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs)

        outputs = self._output_transform(outputs)
        outputs["policy_timing"] = {
            "infer_ms": model_time * 1000,
        }
        return outputs

    @property
    def metadata(self) -> dict[str, Any]:
        # Arbitrary metadata supplied at construction time (e.g. from the train config).
        return self._metadata
|
||||
|
||||
|
||||
class PolicyRecorder(_base_policy.BasePolicy):
    """Policy wrapper that saves every (input, output) pair to disk."""

    def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
        self._policy = policy

        logging.info(f"Dumping policy records to: {record_dir}")
        # Create the output directory up front so each infer() can just write.
        self._record_dir = pathlib.Path(record_dir)
        self._record_dir.mkdir(parents=True, exist_ok=True)
        self._record_step = 0

    @override
    def infer(self, obs: dict) -> dict:  # type: ignore[misc]
        results = self._policy.infer(obs)

        # Flatten the nested dict into "a/b/c"-keyed entries before saving.
        record = flax.traverse_util.flatten_dict({"inputs": obs, "outputs": results}, sep="/")

        step_path = self._record_dir / f"step_{self._record_step}"
        self._record_step += 1

        np.save(step_path, np.asarray(record))
        return results
|
||||
@@ -0,0 +1,94 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Any
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.policies.policy as _policy
|
||||
import openpi.shared.download as download
|
||||
from openpi.training import checkpoints as _checkpoints
|
||||
from openpi.training import config as _config
|
||||
import openpi.transforms as transforms
|
||||
|
||||
|
||||
def create_trained_policy(
    train_config: _config.TrainConfig,
    checkpoint_dir: pathlib.Path | str,
    *,
    repack_transforms: transforms.Group | None = None,
    sample_kwargs: dict[str, Any] | None = None,
    default_prompt: str | None = None,
    norm_stats: dict[str, transforms.NormStats] | None = None,
    pytorch_device: str | None = None,
) -> _policy.Policy:
    """Create a policy from a trained checkpoint.

    Args:
        train_config: The training config to use to create the model.
        checkpoint_dir: The directory to load the model from.
        repack_transforms: Optional transforms that will be applied before any other transforms.
        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
            kwargs will be used.
        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
            data if it doesn't already exist.
        norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
            from the checkpoint directory.
        pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
            If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu".

    Note:
        The function automatically detects whether the model is PyTorch-based by checking for the
        presence of "model.safetensors" in the checkpoint directory.
    """
    repack_transforms = repack_transforms or transforms.Group()
    # NOTE(review): the code below uses `checkpoint_dir / "params"`, so maybe_download is
    # presumably returning a path-like object here — confirm for remote (gs://) inputs.
    checkpoint_dir = download.maybe_download(str(checkpoint_dir))

    # Check if this is a PyTorch model by looking for model.safetensors
    weight_path = os.path.join(checkpoint_dir, "model.safetensors")
    is_pytorch = os.path.exists(weight_path)

    logging.info("Loading model...")
    if is_pytorch:
        model = train_config.model.load_pytorch(train_config, weight_path)
        # Downcast selected parameters to bfloat16 for inference.
        model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
    else:
        # JAX checkpoint: restore parameters in bfloat16 and build the model from them.
        model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
    if norm_stats is None:
        # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
        # that the policy is using the same normalization stats as the original training process.
        if data_config.asset_id is None:
            raise ValueError("Asset id is required to load norm stats.")
        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)

    # Determine the device to use for PyTorch models
    if is_pytorch and pytorch_device is None:
        try:
            import torch

            pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            pytorch_device = "cpu"

    # Transform order: repack -> default prompt -> data transforms -> normalize ->
    # model transforms; outputs run the same pipeline in reverse.
    return _policy.Policy(
        model,
        transforms=[
            *repack_transforms.inputs,
            transforms.InjectDefaultPrompt(default_prompt),
            *data_config.data_transforms.inputs,
            transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
        output_transforms=[
            *data_config.model_transforms.outputs,
            transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.data_transforms.outputs,
            *repack_transforms.outputs,
        ],
        sample_kwargs=sample_kwargs,
        metadata=train_config.policy_metadata,
        is_pytorch=is_pytorch,
        pytorch_device=pytorch_device if is_pytorch else None,
    )
|
||||
@@ -0,0 +1,34 @@
|
||||
from openpi_client import action_chunk_broker
|
||||
import pytest
|
||||
|
||||
from openpi.policies import aloha_policy
|
||||
from openpi.policies import policy_config as _policy_config
|
||||
from openpi.training import config as _config
|
||||
|
||||
|
||||
@pytest.mark.manual
def test_infer():
    """End-to-end smoke test: downloads a real checkpoint, hence marked manual."""
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")

    example = aloha_policy.make_aloha_example()
    result = policy.infer(example)

    # Aloha actions are 14-dimensional per step.
    assert result["actions"].shape == (config.model.action_horizon, 14)
|
||||
|
||||
|
||||
@pytest.mark.manual
def test_broker():
    """Checks that ActionChunkBroker serves per-step actions from chunked inference. Marked
    manual because it downloads a real checkpoint."""
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")

    broker = action_chunk_broker.ActionChunkBroker(
        policy,
        # Only execute the first half of the chunk.
        action_horizon=config.model.action_horizon // 2,
    )

    example = aloha_policy.make_aloha_example()
    for _ in range(config.model.action_horizon):
        outputs = broker.infer(example)
        # Each broker step yields a single 14-dim action.
        assert outputs["actions"].shape == (14,)
|
||||
@@ -0,0 +1,221 @@
|
||||
import dataclasses
|
||||
from typing import ClassVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from pdb import set_trace
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Reala2dInputs(transforms.DataTransformFn):
    """Inputs for the A2D policy."""

    # If true, convert joint/gripper values from the standard Aloha space to the
    # pi internal runtime space.
    adapt_to_pi: bool = True

    # All input cameras must come from this set. Missing cameras are replaced with
    # black images and their `image_mask` entry is set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_a2d(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # The base camera is assumed to always be present.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist cameras are optional; pad with zeros (and a False mask) when absent.
            wrist_mapping = (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            )
            for dest, source in wrist_mapping:
                present = source in in_images
                images[dest] = in_images[source] if present else np.zeros_like(base_image)
                image_masks[dest] = np.True_ if present else np.False_

            inputs = {
                "image": images,
                "image_mask": image_masks,
                "state": data["state"],
            }
        else:
            # State-only observation (no cameras attached).
            inputs = {
                "state": data["state"],
            }

        # Actions are only present during training.
        if "actions" in data:
            inputs["actions"] = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Reala2dOutputs(transforms.DataTransformFn):
    """Convert model outputs back into the A2D action space."""

    # If true, convert from pi runtime conventions back to the platform's.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only return the first 16 dims.
        actions = np.asarray(data["actions"][:, :16])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
|
||||
|
||||
|
||||
def _joint_flip_mask() -> np.ndarray:
|
||||
"""Used to convert between aloha and pi joint angles."""
|
||||
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
def _normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def _unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def _gripper_to_angular(value):
    """Map an aloha linear gripper position into pi0's normalized angular space."""
    # Undo aloha's normalization into linear space. Constants are
    # PUPPET_GRIPPER_POSITION_OPEN / PUPPET_GRIPPER_POSITION_CLOSED from the Aloha code.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    def linear_to_radian(linear_position, arm_length, horn_radius):
        # Inverse of the angular-to-linear transform in the Interbotix code.
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110);
    # with 4096 total counts and a zero of 2048 that is (0.5476, 1.6296) in radians.
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
|
||||
|
||||
|
||||
def _gripper_from_angular(value):
    """Convert a pi0 gripper command into the angular range Aloha expects."""
    # The model predictions are already in radians; only shift by the pi0 zero
    # point (see _gripper_to_angular for where 0.5476 comes from).
    shifted = value + 0.5476

    # PUPPET_GRIPPER_JOINT_OPEN / PUPPET_GRIPPER_JOINT_CLOSE from the Aloha code.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def _gripper_from_angular_inv(value):
    """Exact inverse of `_gripper_from_angular`."""
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
|
||||
|
||||
|
||||
def _decode_a2d(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten raw state/action dicts and convert images to HWC uint8, in place."""
    data["state"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def to_hwc_uint8(img):
        arr = np.asarray(img)
        # Float images are assumed to lie in [0, 1]; rescale to uint8.
        if np.issubdtype(arr.dtype, np.floating):
            arr = (255 * arr).astype(np.uint8)
        # [channel, height, width] -> [height, width, channel].
        return einops.rearrange(arr, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: to_hwc_uint8(img) for name, img in data["images"].items()}
    return data
|
||||
|
||||
|
||||
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
joint = state["joint"]
|
||||
gripper = state["gripper"]
|
||||
state_left_arm_gripper = np.concatenate(
|
||||
[
|
||||
joint[:7],
|
||||
gripper[:1],
|
||||
],
|
||||
axis=-1
|
||||
)
|
||||
state_right_arm_gripper = np.concatenate(
|
||||
[
|
||||
joint[7:14],
|
||||
gripper[1:2],
|
||||
],
|
||||
axis=-1
|
||||
)
|
||||
state = np.concatenate(
|
||||
[
|
||||
state_left_arm_gripper,
|
||||
state_right_arm_gripper,
|
||||
],
|
||||
axis=-1
|
||||
)
|
||||
return state
|
||||
|
||||
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
joint = action["joint"]
|
||||
gripper = action["gripper"]
|
||||
action_left_arm_gripper = np.concatenate(
|
||||
[
|
||||
joint[:,:7],
|
||||
gripper[:,:1],
|
||||
],
|
||||
axis=-1
|
||||
)
|
||||
action_right_arm_gripper = np.concatenate(
|
||||
[
|
||||
joint[:,7:14],
|
||||
gripper[:,1:2],
|
||||
],
|
||||
axis=-1
|
||||
)
|
||||
action = np.concatenate(
|
||||
[
|
||||
action_left_arm_gripper,
|
||||
action_right_arm_gripper,
|
||||
],
|
||||
axis=-1
|
||||
)
|
||||
return action
|
||||
|
||||
|
||||
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
|
||||
|
||||
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
@@ -0,0 +1,205 @@
|
||||
import dataclasses
|
||||
from typing import ClassVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from pdb import set_trace
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class RealLift2Inputs(transforms.DataTransformFn):
    """Repack raw Lift2 observations into the model's input format."""

    # If true, adapt raw values to the pi internal runtime conventions.
    adapt_to_pi: bool = True

    # The expected camera names. All input cameras must be in this set. Missing cameras are
    # replaced with black images and the corresponding `image_mask` entry is set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # The base camera is assumed to always be present.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist cameras are optional; missing ones are padded with black frames.
            for dest, source in (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            ):
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {"image": images, "image_mask": image_masks, "state": data["state"]}
        else:
            inputs = {"state": data["state"]}

        # Actions are only available during training.
        if "actions" in data:
            actions = np.asarray(data["actions"])
            inputs["actions"] = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class RealLift2Outputs(transforms.DataTransformFn):
    """Convert model outputs back into the Lift2 action space."""

    # If true, convert from pi runtime conventions back to the platform's.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only return the first 14 dims.
        actions = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
|
||||
|
||||
|
||||
def _joint_flip_mask() -> np.ndarray:
|
||||
"""Used to convert between aloha and pi joint angles."""
|
||||
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
def _normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def _unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def _gripper_to_angular(value):
    """Map an aloha linear gripper position into pi0's normalized angular space."""
    # Undo aloha's normalization into linear space. Constants are
    # PUPPET_GRIPPER_POSITION_OPEN / PUPPET_GRIPPER_POSITION_CLOSED from the Aloha code.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    def linear_to_radian(linear_position, arm_length, horn_radius):
        # Inverse of the angular-to-linear transform in the Interbotix code.
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110);
    # with 4096 total counts and a zero of 2048 that is (0.5476, 1.6296) in radians.
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
|
||||
|
||||
|
||||
def _gripper_from_angular(value):
    """Convert a pi0 gripper command into the angular range Aloha expects."""
    # The model predictions are already in radians; only shift by the pi0 zero
    # point (see _gripper_to_angular for where 0.5476 comes from).
    shifted = value + 0.5476

    # PUPPET_GRIPPER_JOINT_OPEN / PUPPET_GRIPPER_JOINT_CLOSE from the Aloha code.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def _gripper_from_angular_inv(value):
    """Exact inverse of `_gripper_from_angular`."""
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
|
||||
|
||||
|
||||
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten raw state/action dicts and convert images to HWC uint8, in place.

    State layout is [left_arm_joints(6), left_gripper(1), right_arm_joints(6), right_gripper(1)].
    """
    data["state"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def to_hwc_uint8(img):
        arr = np.asarray(img)
        # Float images are assumed to lie in [0, 1]; rescale to uint8.
        if np.issubdtype(arr.dtype, np.floating):
            arr = (255 * arr).astype(np.uint8)
        # [channel, height, width] -> [height, width, channel].
        return einops.rearrange(arr, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: to_hwc_uint8(img) for name, img in data["images"].items()}
    return data
|
||||
|
||||
|
||||
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
state_left_arm = state["left_joint"]
|
||||
state_left_gripper = state["left_gripper"]
|
||||
state_right_arm = state["right_joint"]
|
||||
state_right_gripper = state["right_gripper"]
|
||||
if state_left_arm.ndim - state_left_gripper.ndim == 1:
|
||||
if state_left_gripper.ndim == 0:
|
||||
state_left_gripper = state_left_gripper[None]
|
||||
state_right_gripper = state_right_gripper[None]
|
||||
state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=0)
|
||||
elif state_left_gripper.ndim == 1:
|
||||
state_left_gripper = state_left_gripper[:, None]
|
||||
state_right_gripper = state_right_gripper[:, None]
|
||||
state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=1)
|
||||
return state
|
||||
|
||||
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
action_left_arm = action["left_joint"]
|
||||
action_left_gripper = action["left_gripper"]
|
||||
action_right_arm = action["right_joint"]
|
||||
action_right_gripper = action["right_gripper"]
|
||||
|
||||
if action_left_arm.ndim - action_left_gripper.ndim == 1:
|
||||
if action_left_gripper.ndim == 0:
|
||||
action_left_gripper = action_left_gripper[None]
|
||||
action_right_gripper = action_right_gripper[None]
|
||||
action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=0)
|
||||
elif action_left_gripper.ndim == 1:
|
||||
action_left_gripper = action_left_gripper[:, None]
|
||||
action_right_gripper = action_right_gripper[:, None]
|
||||
action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=1)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
|
||||
|
||||
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
@@ -0,0 +1,207 @@
|
||||
import dataclasses
|
||||
from typing import ClassVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from pdb import set_trace
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Sim2RealSplitAlohaInputs(transforms.DataTransformFn):
    """Repack raw Split Aloha observations into the model's input format."""

    # If true, adapt raw values to the pi internal runtime conventions.
    adapt_to_pi: bool = True

    # The expected camera names. All input cameras must be in this set. Missing cameras are
    # replaced with black images and the corresponding `image_mask` entry is set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # The base camera is assumed to always be present.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist cameras are optional; missing ones are padded with black frames.
            for dest, source in (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            ):
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {"image": images, "image_mask": image_masks, "state": data["state"]}
        else:
            inputs = {"state": data["state"]}

        # Actions are only available during training.
        if "actions" in data:
            actions = np.asarray(data["actions"])
            inputs["actions"] = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Sim2RealSplitAlohaOutputs(transforms.DataTransformFn):
    """Convert model outputs back into the Split Aloha action space."""

    # If true, this will convert the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only return the first 14 dims.
        actions = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
|
||||
|
||||
|
||||
def _joint_flip_mask() -> np.ndarray:
|
||||
"""Used to convert between aloha and pi joint angles."""
|
||||
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
def _normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def _unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def _gripper_to_angular(value):
    """Map an aloha linear gripper position into pi0's normalized angular space."""
    # Undo aloha's normalization into linear space. Constants are
    # PUPPET_GRIPPER_POSITION_OPEN / PUPPET_GRIPPER_POSITION_CLOSED from the Aloha code.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    def linear_to_radian(linear_position, arm_length, horn_radius):
        # Inverse of the angular-to-linear transform in the Interbotix code.
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110);
    # with 4096 total counts and a zero of 2048 that is (0.5476, 1.6296) in radians.
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
|
||||
|
||||
|
||||
def _gripper_from_angular(value):
    """Convert a pi0 gripper command into the angular range Aloha expects."""
    # The model predictions are already in radians; only shift by the pi0 zero
    # point (see _gripper_to_angular for where 0.5476 comes from).
    shifted = value + 0.5476

    # PUPPET_GRIPPER_JOINT_OPEN / PUPPET_GRIPPER_JOINT_CLOSE from the Aloha code.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def _gripper_from_angular_inv(value):
    """Exact inverse of `_gripper_from_angular`."""
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
|
||||
|
||||
|
||||
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten raw state/action dicts and convert images to HWC uint8, in place."""
    data["state"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def to_hwc_uint8(img):
        arr = np.asarray(img)
        # Float images are assumed to lie in [0, 1]; rescale to uint8.
        if np.issubdtype(arr.dtype, np.floating):
            arr = (255 * arr).astype(np.uint8)
        # [channel, height, width] -> [height, width, channel].
        return einops.rearrange(arr, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: to_hwc_uint8(img) for name, img in data["images"].items()}
    return data
|
||||
|
||||
|
||||
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
state_left_arm = state["left_joint"]
|
||||
state_left_gripper = state["left_gripper"]
|
||||
state_right_arm = state["right_joint"]
|
||||
state_right_gripper = state["right_gripper"]
|
||||
if state_left_arm.ndim - state_left_gripper.ndim == 1:
|
||||
if state_left_gripper.ndim == 0:
|
||||
state_left_gripper = np.array([0])
|
||||
state_right_gripper = np.array([0])
|
||||
state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=0)
|
||||
elif state_left_gripper.ndim == 1:
|
||||
state_left_gripper = np.array([[0]])
|
||||
state_right_gripper = np.array([[0]])
|
||||
state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=1)
|
||||
return state
|
||||
|
||||
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
action_left_arm = action["left_joint"]
|
||||
action_left_gripper = action["left_gripper_openness"]
|
||||
action_right_arm = action["right_joint"]
|
||||
action_right_gripper = action["right_gripper_openness"]
|
||||
|
||||
if action_left_arm.ndim - action_left_gripper.ndim == 1:
|
||||
if action_left_gripper.ndim == 0:
|
||||
action_left_gripper = action_left_gripper[None]
|
||||
action_right_gripper = action_right_gripper[None]
|
||||
action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=0)
|
||||
elif action_left_gripper.ndim == 1:
|
||||
action_left_gripper = action_left_gripper[:, None]
|
||||
action_right_gripper = action_right_gripper[:, None]
|
||||
action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=1)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
@@ -0,0 +1,185 @@
|
||||
import dataclasses
|
||||
from typing import ClassVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from pdb import set_trace
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SimFrankaInputs(transforms.DataTransformFn):
    """Repack raw Franka observations into the model's input format.

    In addition to the usual `image` / `image_mask` / `state` entries, this
    platform also forwards a `pose` entry produced by `_decode_franka`.
    """

    # If true, adapt raw values to the pi internal runtime conventions.
    adapt_to_pi: bool = True

    # The expected camera names. All input cameras must be in this set. Missing cameras are
    # replaced with black images and the corresponding `image_mask` entry is set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_franka(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # The base camera is assumed to always be present.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist cameras are optional; missing ones are padded with black frames.
            for dest, source in (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            ):
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {
                "image": images,
                "image_mask": image_masks,
                "state": data["state"],
                "pose": data["pose"],
            }
        else:
            inputs = {"state": data["state"], "pose": data["pose"]}

        # Actions are only available during training.
        if "actions" in data:
            actions = np.asarray(data["actions"])
            inputs["actions"] = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SimFrankaOutputs(transforms.DataTransformFn):
    """Outputs for the Franka policy."""

    # If true, convert values back from the pi internal runtime conventions.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only return the first 7 dims.
        actions = np.asarray(data["actions"][:, :7])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
|
||||
|
||||
|
||||
def _joint_flip_mask() -> np.ndarray:
|
||||
"""Used to convert between aloha and pi joint angles."""
|
||||
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
def _normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def _unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def _gripper_to_angular(value):
    """Map an aloha linear gripper position into pi0's normalized angular space."""
    # Undo aloha's normalization into linear space. Constants are
    # PUPPET_GRIPPER_POSITION_OPEN / PUPPET_GRIPPER_POSITION_CLOSED from the Aloha code.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    def linear_to_radian(linear_position, arm_length, horn_radius):
        # Inverse of the angular-to-linear transform in the Interbotix code.
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    radians = linear_to_radian(linear, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110);
    # with 4096 total counts and a zero of 2048 that is (0.5476, 1.6296) in radians.
    return _normalize(radians, min_val=0.5476, max_val=1.6296)
|
||||
|
||||
|
||||
def _gripper_from_angular(value):
    """Convert a pi0 gripper command into the angular range Aloha expects."""
    # The model predictions are already in radians; only shift by the pi0 zero
    # point (see _gripper_to_angular for where 0.5476 comes from).
    shifted = value + 0.5476

    # PUPPET_GRIPPER_JOINT_OPEN / PUPPET_GRIPPER_JOINT_CLOSE from the Aloha code.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def _gripper_from_angular_inv(value):
    """Exact inverse of `_gripper_from_angular`."""
    return _unnormalize(value, min_val=-0.6213, max_val=1.4910) - 0.5476
|
||||
|
||||
|
||||
def _decode_franka(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten raw state/action dicts (state + pose) and convert images to HWC uint8, in place."""
    data["state"], data["pose"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def to_hwc_uint8(img):
        arr = np.asarray(img)
        # Float images are assumed to lie in [0, 1]; rescale to uint8.
        if np.issubdtype(arr.dtype, np.floating):
            arr = (255 * arr).astype(np.uint8)
        # [channel, height, width] -> [height, width, channel].
        return einops.rearrange(arr, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: to_hwc_uint8(img) for name, img in data["images"].items()}
    return data
|
||||
|
||||
|
||||
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
gripper_position = state["gripper_position"][None]
|
||||
gripper_pose = state["gripper_pose"]
|
||||
joint_position = state["joint_position"]
|
||||
state = np.concatenate([joint_position, gripper_position], axis=0)
|
||||
pose = np.concatenate([gripper_pose, gripper_position], axis=0)
|
||||
return state, pose
|
||||
|
||||
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
gripper_pose = action["gripper_pose"]
|
||||
gripper_openness = action["gripper_openness"][..., None]
|
||||
action = np.concatenate([gripper_pose, gripper_openness], axis=1)
|
||||
return action
|
||||
|
||||
|
||||
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
|
||||
|
||||
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
@@ -0,0 +1,208 @@
|
||||
import dataclasses
|
||||
from typing import ClassVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from pdb import set_trace
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SimSplitAlohaInputs(transforms.DataTransformFn):
    """Repack raw simulated Split Aloha observations into the model's input format."""

    # If true, adapt raw values to the pi internal runtime conventions.
    adapt_to_pi: bool = True

    # The expected camera names. All input cameras must be in this set. Missing cameras are
    # replaced with black images and the corresponding `image_mask` entry is set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        if "images" in data:
            in_images = data["images"]
            if set(in_images) - set(self.EXPECTED_CAMERAS):
                raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

            # The base camera is assumed to always be present.
            base_image = in_images["cam_high"]
            images = {"base_0_rgb": base_image}
            image_masks = {"base_0_rgb": np.True_}

            # Wrist cameras are optional; missing ones are padded with black frames.
            for dest, source in (
                ("left_wrist_0_rgb", "cam_left_wrist"),
                ("right_wrist_0_rgb", "cam_right_wrist"),
            ):
                if source in in_images:
                    images[dest] = in_images[source]
                    image_masks[dest] = np.True_
                else:
                    images[dest] = np.zeros_like(base_image)
                    image_masks[dest] = np.False_

            inputs = {"image": images, "image_mask": image_masks, "state": data["state"]}
        else:
            inputs = {"state": data["state"]}

        # Actions are only available during training.
        if "actions" in data:
            actions = np.asarray(data["actions"])
            inputs["actions"] = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SimSplitAlohaOutputs(transforms.DataTransformFn):
    """Convert model outputs back into the simulated Split Aloha action space."""

    # If true, this will convert the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only return the first 14 dims.
        actions = np.asarray(data["actions"][:, :14])
        return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
|
||||
|
||||
|
||||
def _joint_flip_mask() -> np.ndarray:
|
||||
"""Used to convert between aloha and pi joint angles."""
|
||||
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
def _normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def _unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def _gripper_to_angular(value):
    """Convert Aloha's linearized gripper position to pi0's normalized angular space.

    Aloha transforms the gripper positions into a linear space; pi0 is pretrained
    in angular space, so this reverses that transformation.
    """
    # Map the normalized input back to the linear range used by the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED.
    linear = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # Inverse of the angular-to-linear transformation inside the Interbotix code.
    # The constants are taken from the Interbotix code.
    arm_length = 0.036
    horn_radius = 0.022
    cos_term = (horn_radius**2 + linear**2 - arm_length**2) / (2 * horn_radius * linear)
    angular = np.arcsin(np.clip(cos_term, -1.0, 1.0))

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # There are 4096 total encoder counts and aloha uses a zero of 2048.
    # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296).
    return _normalize(angular, min_val=0.5476, max_val=1.6296)
|
||||
|
||||
|
||||
def _gripper_from_angular(value):
    """Convert the gripper position used by pi0 to the one used by Aloha.

    The units stay angular; only the range changes.
    """
    # No output scaling: the trossen model predictions are already in radians.
    # See _gripper_to_angular for the derivation of the 0.5476 offset.
    shifted = value + 0.5476

    # These bounds come from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE.
    return _normalize(shifted, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def _gripper_from_angular_inv(value):
    """Exact inverse of `_gripper_from_angular`."""
    unscaled = _unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return unscaled - 0.5476
|
||||
|
||||
|
||||
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Flatten Aloha's nested state/action dicts and normalize image layout, in place."""
    # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]
    # dim sizes: [7, 1, 7, 1]
    data["state"] = _decode_state(data.pop("state_dict"), adapt_to_pi=adapt_to_pi)
    data["actions"] = _decode_action(data.pop("action_dict"), adapt_to_pi=adapt_to_pi)

    def convert_image(img):
        arr = np.asarray(img)
        # Float images are assumed to be in [0, 1]; scale to uint8.
        if np.issubdtype(arr.dtype, np.floating):
            arr = (255 * arr).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(arr, "c h w -> h w c")

    if "images" in data:
        data["images"] = {name: convert_image(img) for name, img in data["images"].items()}
    return data
|
||||
|
||||
|
||||
def _decode_state(state, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
state_left_arm = state["left_joint"]
|
||||
state_left_gripper = state["left_gripper"]
|
||||
state_right_arm = state["right_joint"]
|
||||
state_right_gripper = state["right_gripper"]
|
||||
if state_left_arm.ndim - state_left_gripper.ndim == 1:
|
||||
if state_left_gripper.ndim == 0:
|
||||
state_left_gripper = state_left_gripper[None]
|
||||
state_right_gripper = state_right_gripper[None]
|
||||
state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=0)
|
||||
elif state_left_gripper.ndim == 1:
|
||||
state_left_gripper = state_left_gripper[:, None]
|
||||
state_right_gripper = state_right_gripper[:, None]
|
||||
state = np.concatenate([state_left_arm, state_left_gripper, state_right_arm, state_right_gripper], axis=1)
|
||||
return state
|
||||
|
||||
def _decode_action(action, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
action_left_arm = action["left_joint"]
|
||||
action_left_gripper = action["left_gripper_openness"]
|
||||
action_right_arm = action["right_joint"]
|
||||
action_right_gripper = action["right_gripper_openness"]
|
||||
|
||||
if action_left_arm.ndim - action_left_gripper.ndim == 1:
|
||||
if action_left_gripper.ndim == 0:
|
||||
action_left_gripper = action_left_gripper[None]
|
||||
action_right_gripper = action_right_gripper[None]
|
||||
action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=0)
|
||||
elif action_left_gripper.ndim == 1:
|
||||
action_left_gripper = action_left_gripper[:, None]
|
||||
action_right_gripper = action_right_gripper[:, None]
|
||||
action = np.concatenate([action_left_arm, action_left_gripper, action_right_arm, action_right_gripper], axis=1)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
|
||||
|
||||
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
return actions
|
||||
0
policy/openpi-InternData-A1/src/openpi/py.typed
Normal file
0
policy/openpi-InternData-A1/src/openpi/py.typed
Normal file
@@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
import http
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client import msgpack_numpy
|
||||
import websockets.asyncio.server as _server
|
||||
import websockets.frames
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    Currently only implements the `load` and `infer` methods.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Store the policy and connection settings.

        Args:
            policy: The policy whose `infer` method is served per message.
            host: Interface to bind; defaults to all interfaces.
            port: Port to bind; None lets the OS pick an ephemeral port.
            metadata: Arbitrary dict sent to each client right after connect.
        """
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        """Blocking entry point: run the async server on a fresh event loop."""
        asyncio.run(self.run())

    async def run(self):
        """Start the websocket server and serve until cancelled."""
        async with _server.serve(
            self._handler,
            self._host,
            self._port,
            # Disable compression and message-size limits: payloads are
            # msgpack-encoded numpy arrays that can be large.
            compression=None,
            max_size=None,
            process_request=_health_check,
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: _server.ServerConnection):
        """Per-connection loop: send metadata once, then answer infer requests.

        Protocol: each incoming message is a msgpack-encoded observation dict;
        each reply is the policy's action dict with server timing attached.
        """
        logger.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        # First frame of every connection is the server metadata.
        await websocket.send(packer.pack(self._metadata))

        prev_total_time = None
        while True:
            try:
                start_time = time.monotonic()
                obs = msgpack_numpy.unpackb(await websocket.recv())

                infer_time = time.monotonic()
                action = self._policy.infer(obs)
                infer_time = time.monotonic() - infer_time

                action["server_timing"] = {
                    "infer_ms": infer_time * 1000,
                }
                if prev_total_time is not None:
                    # We can only record the last total time since we also want to include the send time.
                    action["server_timing"]["prev_total_ms"] = prev_total_time * 1000

                await websocket.send(packer.pack(action))
                prev_total_time = time.monotonic() - start_time

            except websockets.ConnectionClosed:
                # Normal client disconnect: leave the loop quietly.
                logger.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                # Ship the traceback to the client before closing so remote
                # debugging is possible, then re-raise for server-side logging.
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise
|
||||
|
||||
|
||||
def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
    """Answer /healthz probes before the websocket handshake; pass everything else through."""
    if request.path != "/healthz":
        # Continue with the normal request handling.
        return None
    return connection.respond(http.HTTPStatus.OK, "OK\n")
|
||||
@@ -0,0 +1,89 @@
|
||||
import contextlib
|
||||
import functools as ft
|
||||
import inspect
|
||||
from typing import TypeAlias, TypeVar, cast
|
||||
|
||||
import beartype
|
||||
import jax
|
||||
import jax._src.tree_util as private_tree_util
|
||||
import jax.core
|
||||
from jaxtyping import ArrayLike
|
||||
from jaxtyping import Bool # noqa: F401
|
||||
from jaxtyping import DTypeLike # noqa: F401
|
||||
from jaxtyping import Float
|
||||
from jaxtyping import Int # noqa: F401
|
||||
from jaxtyping import Key # noqa: F401
|
||||
from jaxtyping import Num # noqa: F401
|
||||
from jaxtyping import PyTree
|
||||
from jaxtyping import Real # noqa: F401
|
||||
from jaxtyping import UInt8 # noqa: F401
|
||||
from jaxtyping import config
|
||||
from jaxtyping import jaxtyped
|
||||
import jaxtyping._decorator
|
||||
import torch
|
||||
|
||||
# patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277.
|
||||
# the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`,
|
||||
# `jax.Sharding`, or even <object>) due to JAX tracing operations. this patch skips typechecking when the stack trace
|
||||
# contains `jax._src.tree_util`, which should only be the case during tree unflattening.
|
||||
# Keep a reference to the unpatched implementation so the patched checker can delegate to it.
_original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_annotations  # noqa: SLF001
# Redefine Array to include both JAX arrays and PyTorch tensors
Array = jax.Array | torch.Tensor
|
||||
|
||||
|
||||
def _check_dataclass_annotations(self, typechecker):
    """Patched jaxtyping checker that skips validation during JAX/flax tree operations.

    Custom PyTree nodes can be temporarily filled with arbitrary placeholder types
    (e.g. `jax.ShapeDtypeStruct`) while JAX traces/unflattens, so typechecking is
    suppressed whenever those modules appear on the call stack.
    """
    skip_modules = {"jax._src.tree_util", "flax.nnx.transforms.compilation"}
    for frame_info in inspect.stack():
        if frame_info.frame.f_globals.get("__name__") in skip_modules:
            return None
    return _original_check_dataclass_annotations(self, typechecker)
|
||||
|
||||
|
||||
# Install the patched checker defined above so jaxtyping skips dataclass
# typechecking while JAX tree machinery is on the stack.
jaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations  # noqa: SLF001

# Common type aliases used across the model code.
KeyArrayLike: TypeAlias = jax.typing.ArrayLike
Params: TypeAlias = PyTree[Float[ArrayLike, "..."]]

T = TypeVar("T")
|
||||
|
||||
|
||||
# runtime type-checking decorator
|
||||
def typecheck(t: T) -> T:
    """Runtime type-checking decorator: jaxtyping shape checks backed by beartype."""
    decorator = ft.partial(jaxtyped, typechecker=beartype.beartype)
    return cast(T, decorator(t))
|
||||
|
||||
|
||||
@contextlib.contextmanager
def disable_typechecking():
    """Context manager that temporarily disables jaxtyping runtime checks.

    The previous setting is restored in a `finally` block so that an exception
    raised inside the `with` body no longer leaves typechecking permanently
    disabled (the original implementation skipped restoration on error).
    """
    initial = config.jaxtyping_disable
    config.update("jaxtyping_disable", True)  # noqa: FBT003
    try:
        yield
    finally:
        # Restore even if the body raised.
        config.update("jaxtyping_disable", initial)
|
||||
|
||||
|
||||
def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False):
    """Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer
    error message than if `jax.tree.map` is naively used on PyTrees with different structures.

    Args:
        expected: Reference PyTree.
        got: PyTree to compare against the reference.
        check_shapes: Also require matching `.shape` on every pair of leaves.
        check_dtypes: Also require matching `.dtype` on every pair of leaves.

    Raises:
        ValueError: On any structure, shape, or dtype mismatch, with the
            offending keypath in the message.
    """
    # NOTE: relies on the private jax API `jax._src.tree_util.equality_errors`
    # to get per-keypath explanations of the structural difference.
    if errors := list(private_tree_util.equality_errors(expected, got)):
        raise ValueError(
            "PyTrees have different structure:\n"
            + (
                "\n".join(
                    f"   - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\n"
                    for path, thing1, thing2, explanation in errors
                )
            )
        )

    if check_shapes or check_dtypes:
        # Structures already match, so a keyed tree_map over both is safe here.
        def check(kp, x, y):
            if check_shapes and x.shape != y.shape:
                raise ValueError(f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}")

            if check_dtypes and x.dtype != y.dtype:
                raise ValueError(f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}")

        jax.tree_util.tree_map_with_path(check, expected, got)
|
||||
194
policy/openpi-InternData-A1/src/openpi/shared/download.py
Normal file
194
policy/openpi-InternData-A1/src/openpi/shared/download.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
import filelock
|
||||
import fsspec
|
||||
import fsspec.generic
|
||||
import tqdm_loggable.auto as tqdm
|
||||
|
||||
# Environment variable to control cache directory path, ~/.cache/openpi will be used by default.
|
||||
_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"
|
||||
DEFAULT_CACHE_DIR = "~/.cache/openpi"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_cache_dir() -> pathlib.Path:
    """Return the local download cache directory, creating it if needed.

    The location defaults to `~/.cache/openpi` and can be overridden with the
    `OPENPI_DATA_HOME` environment variable. The directory is opened up with
    shared permissions so containerized and host processes can both use it.
    """
    raw_dir = os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)
    cache_dir = pathlib.Path(raw_dir).expanduser().resolve()
    cache_dir.mkdir(parents=True, exist_ok=True)
    _set_folder_permission(cache_dir)
    return cache_dir
|
||||
|
||||
|
||||
def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
    """Download a file or directory from a remote filesystem to the local cache, and return the local path.

    If the local file already exists, it will be returned directly.

    It is safe to call this function concurrently from multiple processes.
    See `get_cache_dir` for more details on the cache directory.

    Args:
        url: URL to the file to download.
        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.

    Returns:
        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.

    Raises:
        FileNotFoundError: If `url` is a local path that does not exist.
        PermissionError: If the cached entry cannot be replaced due to filesystem permissions.
    """
    # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem.
    parsed = urllib.parse.urlparse(url)

    # Short circuit if this is a local path.
    if parsed.scheme == "":
        path = pathlib.Path(url)
        if not path.exists():
            raise FileNotFoundError(f"File not found at {url}")
        return path.resolve()

    cache_dir = get_cache_dir()

    # Mirror the remote layout under the cache root: <cache>/<netloc>/<path>.
    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
    local_path = local_path.resolve()

    # Check if the cache should be invalidated.
    invalidate_cache = False
    if local_path.exists():
        if force_download or _should_invalidate_cache(cache_dir, local_path):
            invalidate_cache = True
        else:
            # Cache hit: return without touching the network.
            return local_path

    try:
        # A sibling ".lock" file serializes concurrent downloads of the same entry
        # across processes.
        lock_path = local_path.with_suffix(".lock")
        with filelock.FileLock(lock_path):
            # Ensure consistent permissions for the lock file.
            _ensure_permissions(lock_path)
            # First, remove the existing cache if it is expired.
            if invalidate_cache:
                logger.info(f"Removing expired cached entry: {local_path}")
                if local_path.is_dir():
                    shutil.rmtree(local_path)
                else:
                    local_path.unlink()

            # Download the data to a local cache. A ".partial" scratch path is used
            # and then moved into place so an interrupted download never leaves a
            # half-written entry at the final path.
            logger.info(f"Downloading {url} to {local_path}")
            scratch_path = local_path.with_suffix(".partial")
            _download_fsspec(url, scratch_path, **kwargs)

            shutil.move(scratch_path, local_path)
            _ensure_permissions(local_path)

    except PermissionError as e:
        msg = (
            f"Local file permission error was encountered while downloading {url}. "
            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
        )
        raise PermissionError(msg) from e

    return local_path
|
||||
|
||||
|
||||
def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
    """Download a file or directory from a remote filesystem to `local_path`.

    The transfer runs in a background thread while this thread polls the size of
    the partially downloaded data to drive a progress bar.

    Args:
        url: Remote URL understood by fsspec.
        local_path: Local destination path.
        **kwargs: Additional arguments forwarded to fsspec.

    Raises:
        Any exception raised by the underlying fsspec transfer.
    """
    fs, _ = fsspec.core.url_to_fs(url, **kwargs)
    info = fs.info(url)
    # Folders are represented by 0-byte objects with a trailing forward slash.
    if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))):
        total_size = fs.du(url)
    else:
        total_size = info["size"]
    with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
        # Context-manage the executor so the worker thread is always joined.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(fs.get, url, local_path, recursive=is_dir)
            while not future.done():
                # Poll bytes written so far (rglob is empty until files appear).
                current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file())
                pbar.update(current_size - pbar.n)
                time.sleep(1)
            # Propagate any exception from the transfer. Previously the future's
            # result was never read, so failed downloads were silently ignored
            # and surfaced later as a confusing missing-scratch-file error.
            future.result()
        pbar.update(total_size - pbar.n)
|
||||
|
||||
|
||||
def _set_permission(path: pathlib.Path, target_permission: int):
    """chmod `path` to `target_permission` unless the bits are already set.

    chmod requires executable permission to be set, so we skip if the permission
    is already match with the target.
    """
    current_mode = path.stat().st_mode
    if current_mode & target_permission == target_permission:
        logger.debug(f"Skipping {path} because it already has correct permissions")
        return
    path.chmod(target_permission)
    logger.debug(f"Set {path} to {target_permission}")
|
||||
|
||||
|
||||
def _set_folder_permission(folder_path: pathlib.Path) -> None:
    """Set folder permission to be read, write and searchable."""
    rwx_all = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO
    _set_permission(folder_path, rwx_all)
|
||||
|
||||
|
||||
def _ensure_permissions(path: pathlib.Path) -> None:
    """Since we are sharing cache directory with containerized runtime as well as training script, we need to
    ensure that the cache directory has the correct permissions.

    Opens up every directory from the cache root down to `path`, then walks the
    subtree making files read/writable (preserving the executable bit on
    scripts) and directories fully accessible.
    """

    def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None:
        # Open up each intermediate directory between the cache root and `path`.
        cache_dir = get_cache_dir()
        relative_path = path.relative_to(cache_dir)
        moving_path = cache_dir
        for part in relative_path.parts:
            _set_folder_permission(moving_path / part)
            moving_path = moving_path / part

    def _set_file_permission(file_path: pathlib.Path) -> None:
        """Set all files to be read & writable, if it is a script, keep it as a script."""
        file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH
        # Keep the executable bits if the owner already had execute permission.
        if file_path.stat().st_mode & 0o100:
            _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        else:
            _set_permission(file_path, file_rw)

    _setup_folder_permission_between_cache_dir_and_path(path)
    # Loop variables renamed from `file`/`dir` to avoid shadowing the `dir` builtin.
    for root, dir_names, file_names in os.walk(str(path)):
        root_path = pathlib.Path(root)
        for file_name in file_names:
            _set_file_permission(root_path / file_name)

        for dir_name in dir_names:
            _set_folder_permission(root_path / dir_name)
|
||||
|
||||
|
||||
def _get_mtime(year: int, month: int, day: int) -> float:
|
||||
"""Get the mtime of a given date at midnight UTC."""
|
||||
date = datetime.datetime(year, month, day, tzinfo=datetime.UTC)
|
||||
return time.mktime(date.timetuple())
|
||||
|
||||
|
||||
# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format).
# Partial matching will be used from top to bottom and the first match will be chosen.
# Cached entries will be retained only if they are newer than the expiration timestamp.
# NOTE: order matters — the specific checkpoint entries must precede the
# catch-all "checkpoints/" entry, since the first match wins.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}
|
||||
|
||||
|
||||
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Invalidate the cache if it is expired. Return True if the cache was invalidated."""
    assert local_path.exists(), f"File not found at {local_path}"

    relative_path = str(local_path.relative_to(cache_dir))
    # First matching pattern wins; entries not newer than the stamp expire.
    expire_time = next(
        (stamp for pattern, stamp in _INVALIDATE_CACHE_DIRS.items() if pattern.match(relative_path)),
        None,
    )
    if expire_time is None:
        return False
    return local_path.stat().st_mtime <= expire_time
|
||||
@@ -0,0 +1,54 @@
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
import openpi.shared.download as download
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_openpi_data_home(tmp_path_factory):
    """Point OPENPI_DATA_HOME at a temp dir so tests never touch the real cache."""
    temp_dir = tmp_path_factory.mktemp("openpi_data")
    with pytest.MonkeyPatch().context() as mp:
        mp.setenv("OPENPI_DATA_HOME", str(temp_dir))
        yield
|
||||
|
||||
|
||||
def test_download_local(tmp_path: pathlib.Path):
    """Local paths are returned as-is; missing local paths raise."""
    local_path = tmp_path / "local"
    local_path.touch()

    result = download.maybe_download(str(local_path))
    assert result == local_path

    with pytest.raises(FileNotFoundError):
        download.maybe_download("bogus")
|
||||
|
||||
|
||||
def test_download_gs_dir():
    """A GCS directory is downloaded into the cache; repeated calls hit the cache."""
    remote_path = "gs://openpi-assets/testdata/random"

    local_path = download.maybe_download(remote_path)
    assert local_path.exists()

    # Second call should resolve to the same cached location.
    new_local_path = download.maybe_download(remote_path)
    assert new_local_path == local_path
|
||||
|
||||
|
||||
def test_download_gs():
    """A single GCS file is downloaded into the cache; repeated calls hit the cache."""
    remote_path = "gs://openpi-assets/testdata/random/random_512kb.bin"

    local_path = download.maybe_download(remote_path)
    assert local_path.exists()

    # Second call should resolve to the same cached location.
    new_local_path = download.maybe_download(remote_path)
    assert new_local_path == local_path
|
||||
|
||||
|
||||
def test_download_fsspec():
    """Extra fsspec kwargs (anonymous GCS token) are forwarded to the filesystem."""
    remote_path = "gs://big_vision/paligemma_tokenizer.model"

    local_path = download.maybe_download(remote_path, gs={"token": "anon"})
    assert local_path.exists()

    # Second call should resolve to the same cached location.
    new_local_path = download.maybe_download(remote_path, gs={"token": "anon"})
    assert new_local_path == local_path
|
||||
126
policy/openpi-InternData-A1/src/openpi/shared/image_tools.py
Normal file
126
policy/openpi-InternData-A1/src/openpi/shared/image_tools.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import functools
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
@at.typecheck
def resize_with_pad(
    images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"],
    height: int,
    width: int,
    method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR,
) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]:
    """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].
    """
    # Temporarily add a batch dimension so the code below can assume 4-D input.
    has_batch_dim = images.ndim == 4
    if not has_batch_dim:
        images = images[None]  # type: ignore
    cur_height, cur_width = images.shape[1:3]
    # Scale by the dominant axis ratio so the resized image fits inside (height, width).
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_images = jax.image.resize(
        images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method
    )
    if images.dtype == jnp.uint8:
        # round from float back to uint8
        resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8)
    elif images.dtype == jnp.float32:
        resized_images = resized_images.clip(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Split padding evenly; any odd remainder goes to the bottom/right edge.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w
    padded_images = jnp.pad(
        resized_images,
        ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)),
        # "Black" is 0 for uint8 images and -1 for [-1, 1] float images.
        constant_values=0 if images.dtype == jnp.uint8 else -1.0,
    )

    if not has_batch_dim:
        padded_images = padded_images[0]
    return padded_images
|
||||
|
||||
|
||||
def resize_with_pad_torch(
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with same shape format as input

    Raises:
        ValueError: If the image dtype is neither uint8 nor float32.
    """
    # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w].
    # Heuristic: a trailing dim of <= 4 is treated as a channel dim — TODO confirm for callers
    # with very narrow channels-first images.
    channels_last = images.shape[-1] <= 4

    # Remember whether WE added the batch dim, so a genuine 4-D input with batch
    # size 1 keeps its batch dimension on output (the previous `batch_size == 1`
    # check squeezed it incorrectly).
    added_batch_dim = images.dim() == 3
    if added_batch_dim:
        images = images.unsqueeze(0)
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    batch_size, channels, cur_height, cur_width = images.shape

    # Calculate resize ratio so the resized image fits inside (height, width).
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    is_uint8 = images.dtype == torch.uint8
    if not is_uint8 and images.dtype != torch.float32:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # F.interpolate does not support uint8 for non-nearest modes, so resize in
    # float and round back afterwards.
    source = images.float() if is_uint8 else images
    resized_images = F.interpolate(
        source, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None
    )

    if is_uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    else:
        resized_images = resized_images.clamp(-1.0, 1.0)

    # Split padding evenly; any odd remainder goes to the bottom/right edge.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # "Black" is 0 for uint8 images and -1 for [-1, 1] float images.
    constant_value = 0 if is_uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to the caller's layout.
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    if added_batch_dim:
        padded_images = padded_images.squeeze(0)

    return padded_images
|
||||
@@ -0,0 +1,37 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from openpi.shared import image_tools
|
||||
|
||||
|
||||
def test_resize_with_pad_shapes():
    """Output shape is always (batch, height, width, c) and black input stays black."""
    # Test case 1: Resize image with larger dimensions
    images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8)  # Input images of shape (batch_size, height, width, channels)
    height = 20
    width = 20
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (2, height, width, 3)
    assert jnp.all(resized_images == 0)

    # Test case 2: Resize image with smaller dimensions
    images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8)
    height = 15
    width = 15
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (3, height, width, 3)
    assert jnp.all(resized_images == 0)

    # Test case 3: Resize image with the same dimensions
    images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8)
    height = 50
    width = 50
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert jnp.all(resized_images == 0)

    # Test case 4: Resize image with odd-numbered padding
    images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8)
    height = 60
    width = 80
    resized_images = image_tools.resize_with_pad(images, height, width)
    assert resized_images.shape == (1, height, width, 3)
    assert jnp.all(resized_images == 0)
|
||||
69
policy/openpi-InternData-A1/src/openpi/shared/nnx_utils.py
Normal file
69
policy/openpi-InternData-A1/src/openpi/shared/nnx_utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
import flax.nnx as nnx
|
||||
import jax
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]:
    """A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process.

    Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much
    more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module
    mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must
    traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details.

    `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by
    `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was
    when `module_jit` was called. Mutations to the module within `meth` are still allowed, but they will be discarded
    after the method call completes.

    Args:
        meth: A bound method of an `nnx.Module` instance.
        *jit_args: Positional arguments forwarded to `jax.jit`.
        **jit_kwargs: Keyword arguments forwarded to `jax.jit`.

    Returns:
        A callable with the same signature as `meth`, backed by a `jax.jit`-compiled function.

    Raises:
        ValueError: If `meth` is not a bound method of an `nnx.Module`.
    """
    if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)):
        raise ValueError("module_jit must only be used on bound methods of nnx.Modules.")

    # Snapshot the module once: graphdef captures static structure, state the arrays.
    graphdef, state = nnx.split(meth.__self__)

    def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R:
        # Rebuild a module from the frozen state inside the traced function;
        # mutations apply only to this temporary merged module.
        module = nnx.merge(graphdef, state)
        return meth.__func__(module, *args, **kwargs)

    jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs)

    @functools.wraps(meth)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # The frozen `state` captured above is passed on every call.
        return jitted_fn(state, *args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class PathRegex:
    """NNX Filter that matches paths using a regex.

    By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument.
    """

    pattern: str | re.Pattern
    sep: str = "/"

    def __post_init__(self):
        # Compile eagerly so __call__ always sees a re.Pattern. The dataclass is
        # frozen, so the field must be set via object.__setattr__.
        if isinstance(self.pattern, re.Pattern):
            return
        object.__setattr__(self, "pattern", re.compile(self.pattern))

    def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool:
        parts = (str(part) for part in path)
        joined_path = self.sep.join(parts)
        assert isinstance(self.pattern, re.Pattern)
        return self.pattern.fullmatch(joined_path) is not None
|
||||
|
||||
|
||||
def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State:
    """Apply a function to the leaves of the state that match the filter."""
    matching = set(state.filter(filter).flat_state())

    def _maybe_apply(key, value):
        # Only transform leaves whose flat key was selected by the filter.
        return fn(value) if key in matching else value

    return state.map(_maybe_apply)
|
||||
199
policy/openpi-InternData-A1/src/openpi/shared/normalize.py
Normal file
199
policy/openpi-InternData-A1/src/openpi/shared/normalize.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import numpydantic
|
||||
import pydantic
|
||||
|
||||
|
||||
@pydantic.dataclasses.dataclass
class NormStats:
    """Per-dimension normalization statistics for one data key (e.g. state, actions)."""

    # All arrays are parallel: one entry per feature dimension.
    mean: numpydantic.NDArray
    std: numpydantic.NDArray
    q01: numpydantic.NDArray | None = None  # 1st quantile
    q99: numpydantic.NDArray | None = None  # 99th quantile
|
||||
|
||||
|
||||
class RunningStats:
    """Compute running statistics of a batch of vectors.

    Maintains streaming mean / mean-of-squares / min / max, plus per-dimension
    histograms used to estimate quantiles without storing all samples.
    """

    def __init__(self):
        self._count = 0  # total number of vectors seen so far
        self._mean = None  # running mean, shape (vector_length,)
        self._mean_of_squares = None  # running E[x^2], used to derive variance
        self._min = None
        self._max = None
        self._histograms = None  # one histogram per dimension, for quantile estimates
        self._bin_edges = None  # bin edges matching each histogram
        self._num_quantile_bins = 5000  # for computing quantiles on the fly

    def update(self, batch: np.ndarray) -> None:
        """
        Update the running statistics with a batch of vectors.

        Args:
            batch (np.ndarray): An array where all dimensions except the last are batch dimensions.
        """
        batch = batch.reshape(-1, batch.shape[-1])
        num_elements, vector_length = batch.shape
        if self._count == 0:
            # First batch: initialize all accumulators from the data itself.
            self._mean = np.mean(batch, axis=0)
            self._mean_of_squares = np.mean(batch**2, axis=0)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
            # Tiny epsilon widens the range so boundary values land inside the bins.
            self._bin_edges = [
                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
                for i in range(vector_length)
            ]
        else:
            if vector_length != self._mean.size:
                raise ValueError("The length of new vectors does not match the initialized vector length.")
            new_max = np.max(batch, axis=0)
            new_min = np.min(batch, axis=0)
            max_changed = np.any(new_max > self._max)
            min_changed = np.any(new_min < self._min)
            self._max = np.maximum(self._max, new_max)
            self._min = np.minimum(self._min, new_min)

            # Histogram bins only cover [min, max]; rebin when the range grows.
            if max_changed or min_changed:
                self._adjust_histograms()

        self._count += num_elements

        batch_mean = np.mean(batch, axis=0)
        batch_mean_of_squares = np.mean(batch**2, axis=0)

        # Update running mean and mean of squares.
        self._mean += (batch_mean - self._mean) * (num_elements / self._count)
        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count)

        self._update_histograms(batch)

    def get_statistics(self) -> NormStats:
        """
        Compute and return the statistics of the vectors processed so far.

        Returns:
            NormStats: mean, std, and histogram-estimated 1st/99th quantiles.

        Raises:
            ValueError: if fewer than 2 vectors have been observed.
        """
        if self._count < 2:
            raise ValueError("Cannot compute statistics for less than 2 vectors.")

        # Var(x) = E[x^2] - E[x]^2; clamp at 0 against floating-point error.
        variance = self._mean_of_squares - self._mean**2
        stddev = np.sqrt(np.maximum(0, variance))
        q01, q99 = self._compute_quantiles([0.01, 0.99])
        return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99)

    def _adjust_histograms(self):
        """Adjust histograms when min or max changes."""
        for i in range(len(self._histograms)):
            old_edges = self._bin_edges[i]
            new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1)

            # Redistribute the existing histogram counts to the new bins.
            # NOTE: each old bin's mass is attributed to its left edge, so this
            # rebinning is approximate — acceptable for quantile estimation.
            new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i])

            self._histograms[i] = new_hist
            self._bin_edges[i] = new_edges

    def _update_histograms(self, batch: np.ndarray) -> None:
        """Update histograms with new vectors."""
        for i in range(batch.shape[1]):
            hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
            self._histograms[i] += hist

    def _compute_quantiles(self, quantiles):
        """Compute quantiles based on histograms.

        Returns one array (of length vector_length) per requested quantile.
        """
        results = []
        for q in quantiles:
            target_count = q * self._count
            q_values = []
            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
                # Find the first bin whose cumulative count reaches the target;
                # report that bin's left edge as the quantile estimate.
                cumsum = np.cumsum(hist)
                idx = np.searchsorted(cumsum, target_count)
                q_values.append(edges[idx])
            results.append(np.array(q_values))
        return results
|
||||
|
||||
class OptimizedRunningStats:
    """Faster, approximate variant of RunningStats.

    Exact mean/std via running sums; quantiles are estimated from a small random
    sample of the data instead of per-dimension histograms.
    """

    def __init__(self, num_quantile_bins=1000):  # reduced bin count vs RunningStats
        self._count = 0
        self._sum = None  # running sum per dimension (float64 for accuracy)
        self._sum_sq = None  # running sum of squares per dimension
        self._min = None
        self._max = None
        self._all_samples = []  # reservoir of sampled rows for quantile estimation
        self._sample_rate = 0.01  # 1% of update() calls contribute samples
        self._num_quantile_bins = num_quantile_bins  # NOTE(review): currently unused

    def update(self, batch: np.ndarray) -> None:
        # Flatten all leading batch dimensions; last axis is the feature vector.
        batch = batch.reshape(-1, batch.shape[-1])
        num_elements = batch.shape[0]

        # Update basic statistics (vectorized).
        if self._count == 0:
            self._sum = np.sum(batch, axis=0, dtype=np.float64)
            self._sum_sq = np.sum(batch**2, axis=0, dtype=np.float64)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
        else:
            self._sum += np.sum(batch, axis=0, dtype=np.float64)
            self._sum_sq += np.sum(batch**2, axis=0, dtype=np.float64)
            self._min = np.minimum(self._min, np.min(batch, axis=0))
            self._max = np.maximum(self._max, np.max(batch, axis=0))

        # Random sampling for quantile computation (avoids storing all data).
        # With probability `_sample_rate`, keep up to 100 rows of this batch.
        if np.random.random() < self._sample_rate:
            sample_idx = np.random.randint(0, num_elements, size=min(100, num_elements))
            self._all_samples.append(batch[sample_idx])

        self._count += num_elements

    def get_statistics(self):
        if self._count < 2:
            raise ValueError("Cannot compute statistics for less than 2 vectors.")

        # Compute mean and standard deviation; clamp variance at 0 against FP error.
        mean = self._sum / self._count
        variance = (self._sum_sq / self._count) - mean**2
        stddev = np.sqrt(np.maximum(0, variance))

        # Compute quantiles from the sampled data.
        if self._all_samples:
            all_sampled = np.concatenate(self._all_samples, axis=0)
            q01 = np.quantile(all_sampled, 0.01, axis=0)
            q99 = np.quantile(all_sampled, 0.99, axis=0)
        else:
            # No batch was sampled (possible for short runs): fall back to zeros.
            q01 = np.zeros_like(mean)
            q99 = np.zeros_like(mean)

        return NormStats(mean=mean, std=stddev, q01=q01, q99=q99)
|
||||
|
||||
class _NormStatsDict(pydantic.BaseModel):
    """Internal wrapper model used for (de)serializing a mapping of norm stats."""

    norm_stats: dict[str, NormStats]
|
||||
|
||||
|
||||
def serialize_json(norm_stats: dict[str, NormStats]) -> str:
    """Serialize the running statistics to a JSON string."""
    wrapper = _NormStatsDict(norm_stats=norm_stats)
    return wrapper.model_dump_json(indent=2)
|
||||
|
||||
|
||||
def deserialize_json(data: str) -> dict[str, NormStats]:
    """Deserialize the running statistics from a JSON string.

    Args:
        data: JSON text produced by `serialize_json`.

    Returns:
        Mapping from data key to its `NormStats`.
    """
    # Parse + validate in one pydantic-v2 step (mirrors `model_dump_json` used in
    # `serialize_json`) instead of json.loads + **kwargs, which skipped top-level
    # validation and double-handled the payload.
    return _NormStatsDict.model_validate_json(data).norm_stats
|
||||
|
||||
|
||||
def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None:
    """Save the normalization stats to a directory."""
    out_path = pathlib.Path(directory) / "norm_stats.json"
    # Create the target directory tree if it does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(serialize_json(norm_stats))
|
||||
|
||||
|
||||
def load(directory: pathlib.Path | str) -> dict[str, NormStats]:
    """Load the normalization stats from a directory."""
    stats_path = pathlib.Path(directory) / "norm_stats.json"
    if not stats_path.exists():
        raise FileNotFoundError(f"Norm stats file not found at: {stats_path}")
    return deserialize_json(stats_path.read_text())
|
||||
@@ -0,0 +1,43 @@
|
||||
import numpy as np
|
||||
|
||||
import openpi.shared.normalize as normalize
|
||||
|
||||
|
||||
def test_normalize_update():
    arr = np.arange(12).reshape(4, 3)  # 4 vectors of length 3

    stats = normalize.RunningStats()
    # Feed one (1, 3) row at a time to exercise the streaming update path.
    for row in np.split(arr, len(arr)):
        stats.update(row)
    results = stats.get_statistics()

    assert np.allclose(results.mean, arr.mean(axis=0))
    assert np.allclose(results.std, arr.std(axis=0))
|
||||
|
||||
|
||||
def test_serialize_deserialize():
    stats = normalize.RunningStats()
    stats.update(np.arange(12).reshape(4, 3))  # 4 vectors of length 3

    original = {"test": stats.get_statistics()}
    # Round-trip through JSON and check mean/std survive unchanged.
    round_tripped = normalize.deserialize_json(normalize.serialize_json(original))
    assert np.allclose(original["test"].mean, round_tripped["test"].mean)
    assert np.allclose(original["test"].std, round_tripped["test"].std)
|
||||
|
||||
|
||||
def test_multiple_batch_dimensions():
    # (2, 3) batch dims over vectors of length 4; update() must flatten them to (6, 4).
    arr = np.random.rand(2, 3, 4)

    stats = normalize.RunningStats()
    stats.update(arr)
    results = stats.get_statistics()

    # Expected stats come from the explicitly flattened array.
    flat = arr.reshape(-1, arr.shape[-1])
    assert np.allclose(results.mean, flat.mean(axis=0))
    assert np.allclose(results.std, flat.std(axis=0))
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Compute normalization statistics for a config.
|
||||
|
||||
This script is used to compute the normalization statistics for a given config. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config assets directory.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.mixture_dataset as _mixture_dataset
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
from pdb import set_trace
|
||||
import openpi.training.weight_loaders as weight_loaders
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
# from openpi.training.config import MultiSimGenieDataConfig, MultiSimSplitAlohaDataConfig, MultiSimFrankaDataConfig, MultiLeRobotReala2dDataConfig, MultiLeRobotRealArxLift2DataConfig, MultiDataConfig, DataConfig, TrainConfig
|
||||
import logging
|
||||
from pdb import set_trace
|
||||
from typing import List
|
||||
class RemoveStrings(transforms.DataTransformFn):
    """Drops string-valued entries (unsupported by JAX, unneeded for norm stats)."""

    def __call__(self, x: dict) -> dict:
        kept = {}
        for key, value in x.items():
            if np.issubdtype(np.asarray(value).dtype, np.str_):
                continue
            kept[key] = value
        return kept
|
||||
|
||||
|
||||
def create_torch_dataloader(
    data_config: List[_config.DataConfig],
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build a torch data loader over the mixture dataset for norm-stat computation.

    Returns the data loader and the number of batches it will yield.
    NOTE(review): despite the annotation, `data_config` is indexed as
    `data_config[0][0]` below, so it appears to be a list of lists of
    DataConfig — confirm against `create_mixture_dataset_no_transform`.
    """
    # if data_config.repo_id is None:
    #     raise ValueError("Data config must have a repo_id")
    # dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
    dataset = _mixture_dataset.create_mixture_dataset_no_transform(data_config, action_horizon, model_config)
    # from pdb import set_trace; set_trace()
    # Transforms are taken from the first dataset's config only — presumably all
    # mixture members share repack/data transforms; verify against callers.
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            *data_config[0][0].repack_transforms.inputs,
            *data_config[0][0].data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
    )
    # Cap the number of frames if requested; shuffle so the cap is a random subset.
    if max_frames is not None and max_frames < len(dataset):
        num_batches = max_frames // batch_size
        shuffle = True
    else:
        num_batches = len(dataset) // batch_size
        shuffle = False
    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        num_batches=num_batches,
    )
    return data_loader, num_batches
|
||||
|
||||
def compute_norm_stats(config_name: str, max_frames: int | None = None):
    """Compute normalization statistics for the `state` and `actions` keys of a config's data.

    Args:
        config_name: Name of the training config to load.
        max_frames: Optional cap on the number of frames used for the estimate.

    Returns:
        Mapping from data key to its computed `NormStats`.
    """
    config = _config.get_config(config_name)
    data_configs_list = []
    for data_config_factory in config.data:
        data_configs = data_config_factory.create(config.model)
        logging.info(f"data_config: {data_configs}")
        data_configs_list.append(data_configs)
    data_loader, num_batches = create_torch_dataloader(
        data_configs_list,
        config.model.action_horizon,
        config.batch_size,
        config.model,
        config.num_workers,
        # Bug fix: this was hard-coded to None, silently ignoring the caller's cap.
        max_frames=max_frames,
    )

    keys = ["state", "actions"]
    running = {key: normalize.RunningStats() for key in keys}

    # Stream batches through the running accumulators; hard cap at 10000 steps to
    # bound runtime on very large mixtures (matches previous behavior).
    for step_id, batch in enumerate(tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"), start=1):
        for key in keys:
            running[key].update(np.asarray(batch[key]))
        if step_id > 10000:
            break

    norm_stats = {key: stat.get_statistics() for key, stat in running.items()}
    print(norm_stats)
    return norm_stats
|
||||
|
||||
159
policy/openpi-InternData-A1/src/openpi/training/checkpoints.py
Normal file
159
policy/openpi-InternData-A1/src/openpi/training/checkpoints.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures as futures
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
from etils import epath
|
||||
import jax
|
||||
import orbax.checkpoint as ocp
|
||||
import orbax.checkpoint.future as future
|
||||
|
||||
from openpi.shared import array_typing as at
|
||||
import openpi.shared.normalize as _normalize
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.utils as training_utils
|
||||
|
||||
|
||||
def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    """Prepare the checkpoint directory and return (manager, resuming).

    `overwrite` wipes any existing directory; `resume` marks an existing directory
    for restoration. If the directory exists and neither flag is set, raises
    FileExistsError so the user must choose.
    """
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    resuming = False
    if checkpoint_dir.exists():
        if overwrite:
            checkpoint_dir.rmtree()
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
        elif resume:
            resuming = True
        else:
            raise FileExistsError(
                f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
                "to indicate how to handle it."
            )

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Three checkpoint items: arbitrary asset files (via callback), full train
    # state, and inference-only params. `create=False` because we made the dir above.
    mngr = ocp.CheckpointManager(
        checkpoint_dir,
        item_handlers={
            "assets": CallbackHandler(),
            "train_state": ocp.PyTreeCheckpointHandler(),
            "params": ocp.PyTreeCheckpointHandler(),
        },
        options=ocp.CheckpointManagerOptions(
            max_to_keep=1,
            keep_period=keep_period,
            create=False,
            async_options=ocp.AsyncOptions(timeout_secs=7200),
        ),
    )

    # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
    # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
    # checkpoint, since it will fail.
    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
        resuming = False

    return mngr, resuming
|
||||
|
||||
|
||||
def save_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int,
):
    """Save a training checkpoint: assets (norm stats), train state, and params."""

    def save_assets(directory: epath.Path):
        # Save the normalization stats.
        data_config = data_loader.data_config()
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(directory / data_config.asset_id, norm_stats)

    # Split params that can be used for inference into a separate item.
    with at.disable_typechecking():
        train_state, params = _split_params(state)
    items = {
        "assets": save_assets,
        "train_state": train_state,
        "params": {"params": params},
    }
    checkpoint_manager.save(step, items)
|
||||
|
||||
|
||||
def restore_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int | None = None,
) -> training_utils.TrainState:
    """Restore a train state saved by `save_state`.

    `state` provides the pytree structure to restore into; `step=None` restores
    the latest checkpoint. `data_loader` is accepted for signature symmetry with
    `save_state` but unused.
    """
    del data_loader

    with at.disable_typechecking():
        # Split params that can be used for inference into a separate item.
        train_state, params = _split_params(state)
        restored = checkpoint_manager.restore(
            step,
            items={
                "train_state": train_state,
                "params": {"params": params},
            },
        )
    return _merge_params(restored["train_state"], restored["params"])
|
||||
|
||||
|
||||
def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
    """Load normalization stats for `asset_id` from `assets_dir`."""
    stats_dir = epath.Path(assets_dir) / asset_id
    loaded = _normalize.load(stats_dir)
    logging.info(f"Loaded norm stats from {stats_dir}")
    return loaded
|
||||
|
||||
|
||||
class Callback(Protocol):
    """Signature of an asset-saving callback, invoked with the checkpoint item directory."""

    def __call__(self, directory: epath.Path) -> None: ...
|
||||
|
||||
|
||||
class CallbackHandler(ocp.AsyncCheckpointHandler):
    """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""

    def save(self, directory: epath.Path, args: CallbackSave):
        # Only the primary process writes assets; other hosts would race on the same files.
        if jax.process_index() == 0:
            args.callback(directory)

    async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
        # Run the synchronous save on a worker thread so checkpointing stays async.
        return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]

    def restore(self, *args, **kwargs):
        raise NotImplementedError("CallbackHandler does not support restore")
|
||||
|
||||
|
||||
@ocp.args.register_with_handler(CallbackHandler, for_save=True)
@dataclasses.dataclass
class CallbackSave(ocp.args.CheckpointArgs):
    """Save-args wrapper: carries the callback invoked with the item directory."""

    callback: Callback
|
||||
|
||||
|
||||
# Registered for API completeness only; CallbackHandler.restore always raises.
@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
class CallbackRestore(ocp.args.CheckpointArgs): ...
|
||||
|
||||
|
||||
def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
    """Split inference-ready params out of the train state, preferring EMA params."""
    if state.ema_params is None:
        # No EMA: hand out the regular params and empty them on the train state.
        return dataclasses.replace(state, params={}), state.params
    # EMA present: hand out EMA params and clear that slot on the train state.
    return dataclasses.replace(state, ema_params=None), state.ema_params
|
||||
|
||||
|
||||
def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
    # Revert the logic inside `_split_params`: a non-empty `params` on the train
    # state means EMA params were the ones split out.
    target_field = "ema_params" if train_state.params else "params"
    return dataclasses.replace(train_state, **{target_field: params["params"]})
|
||||
1904
policy/openpi-InternData-A1/src/openpi/training/config.py
Normal file
1904
policy/openpi-InternData-A1/src/openpi/training/config.py
Normal file
File diff suppressed because it is too large
Load Diff
721
policy/openpi-InternData-A1/src/openpi/training/data_loader.py
Normal file
721
policy/openpi-InternData-A1/src/openpi/training/data_loader.py
Normal file
@@ -0,0 +1,721 @@
|
||||
from collections.abc import Iterator, Sequence
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import typing
|
||||
from typing import Literal, Protocol, SupportsIndex, TypeVar, Dict
|
||||
import sys
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import dataloader
|
||||
from torch.multiprocessing import reductions
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
import openpi.shared.normalize as normalize
|
||||
default_collate_func = dataloader.default_collate
|
||||
import psutil
|
||||
|
||||
def default_collate_override(batch):
    # Disable torch's shared-memory fast path before delegating to the stock
    # collate function. NOTE(review): presumably works around shm/file-descriptor
    # exhaustion with many dataloader workers — confirm.
    dataloader._use_shared_memory = False
    return default_collate_func(batch)

# Install the override so every TorchDataLoader in this module uses it.
setattr(dataloader, 'default_collate', default_collate_override)

# Remove torch's custom (shared-memory based) pickle reducers for its storage
# classes so multiprocessing falls back to the default pickler.
# NOTE(review): the Python 2 branch is dead code on Python 3 — kept as-is.
for t in torch._storage_classes:
    if sys.version_info[0] == 2:
        if t in ForkingPickler.dispatch:
            del ForkingPickler.dispatch[t]
    else:
        if t in ForkingPickler._extra_reducers:
            del ForkingPickler._extra_reducers[t]
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.training.config as _config
|
||||
from openpi.training.droid_rlds_dataset import DroidRldsDataset
|
||||
import openpi.transforms as _transforms
|
||||
from openpi.training.mixture_dataset import create_mixture_dataset
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
import copy
|
||||
from memory_profiler import profile
|
||||
from pdb import set_trace
|
||||
|
||||
class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        # Return the sample at `index`.
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        # Return the total number of samples.
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
|
||||
|
||||
|
||||
class IterableDataset(Protocol[T_co]):
    """Interface for an iterable dataset."""

    def __iter__(self) -> Iterator[T_co]:
        # Yield samples in sequence.
        raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.")

    def __len__(self) -> int:
        # Return the total number of samples.
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
|
||||
|
||||
|
||||
class DataLoader(Protocol[T_co]):
    """Interface for a data loader."""

    def data_config(self) -> _config.DataConfig:
        """Get the data config for this data loader."""
        raise NotImplementedError("Subclasses of DataLoader should implement data_config.")

    def __iter__(self) -> Iterator[T_co]:
        # Yield batches.
        raise NotImplementedError("Subclasses of DataLoader should implement __iter__.")
|
||||
|
||||
|
||||
class TransformedDataset(Dataset[T_co]):
    """Random-access dataset wrapper that applies a composed transform per sample."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw_sample = self._dataset[index]
        return self._transform(raw_sample)

    def __len__(self) -> int:
        return len(self._dataset)
|
||||
|
||||
|
||||
class IterableTransformedDataset(IterableDataset[T_co]):
    """Iterable dataset wrapper that applies a composed transform per sample.

    When `is_batched` is True, each yielded item is assumed to be a dict of
    arrays with a leading batch axis; samples are split out, transformed
    individually, and restacked.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        transforms: Sequence[_transforms.DataTransformFn],
        *,
        is_batched: bool = False,
    ):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)
        self._is_batched = is_batched

    def __iter__(self):
        for sample in self._dataset:
            if self._is_batched:
                # Transforms are designed to be applied to individual samples. So we need to split the batch into
                # individual samples and apply the transform to each sample individually.
                # Infer batch size from any value's leading axis (assumes all match).
                batch_size = next(v.shape[0] for v in sample.values())

                # Split batch into individual samples using tree_map
                individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)]  # noqa: B023

                # Transform each sample
                transformed = [self._transform(s) for s in individual_samples]

                # Recombine batch with tree_map
                yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed)
            else:
                yield self._transform(sample)

    def __len__(self) -> int:
        return len(self._dataset)
|
||||
|
||||
|
||||
class FakeDataset(Dataset):
    """Deterministic synthetic dataset generated from the model's input specs.

    Each index seeds a JAX PRNG, so the same index always yields the same sample.
    """

    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        self._num_samples = num_samples
        self._observation_spec, self._action_spec = model_config.inputs_spec()

    def __getitem__(self, index: SupportsIndex) -> dict:
        # Deterministic per-index randomness.
        rng = jax.random.key(index.__index__())

        def make_from_spec(spec: jax.ShapeDtypeStruct):
            nonlocal rng
            rng, data_rng = jax.random.split(rng)
            # Remove the batch dimension.
            shape = spec.shape[1:]
            if spec.dtype == jnp.float32:
                # Continuous inputs in [-1, 1).
                return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0)
            if spec.dtype == jnp.int32:
                # Token-like inputs; 2048 is an arbitrary vocab-sized bound.
                return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048)
            # Any other dtype (e.g. bool masks) gets zeros.
            return jnp.zeros(shape=shape, dtype=spec.dtype)

        observation = jax.tree.map(make_from_spec, self._observation_spec)
        action = jax.tree.map(make_from_spec, self._action_spec)

        return {
            **observation.to_dict(),
            "actions": action,
        }

    def __len__(self) -> int:
        return self._num_samples
|
||||
|
||||
|
||||
def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training.

    Returns a FakeDataset for repo_id "fake"; otherwise loads a LeRobot dataset
    with action chunks of length `action_horizon` resolved via delta timestamps.
    """
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    dataset = lerobot_dataset.LeRobotDataset(
        data_config.repo_id,
        delta_timestamps={
            # Each action key yields `action_horizon` consecutive frames at dataset fps.
            key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
        },
    )

    if data_config.prompt_from_task:
        # Derive the language prompt from the LeRobot task annotations.
        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])

    return dataset
|
||||
|
||||
|
||||
def create_rlds_dataset(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    shuffle: bool = False,
) -> Dataset:
    """Create an RLDS-backed dataset (currently DROID is the only supported source)."""
    dataset_kwargs = {
        "data_dir": data_config.rlds_data_dir,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "action_chunk_size": action_horizon,
        "action_space": data_config.action_space,
        "filter_dict_path": data_config.filter_dict_path,
    }
    return DroidRldsDataset(**dataset_kwargs)
|
||||
|
||||
|
||||
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Transform the dataset by applying the data transforms."""
    norm_stats = {}
    if not skip_norm_stats and data_config.repo_id != "fake":
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats

    # Pipeline order: repack -> data transforms -> normalize -> model transforms.
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return TransformedDataset(dataset, pipeline)
|
||||
|
||||
|
||||
def transform_iterable_dataset(
    dataset: IterableDataset,
    data_config: _config.DataConfig,
    *,
    skip_norm_stats: bool = False,
    is_batched: bool = False,
) -> IterableDataset:
    """Transform the dataset by applying the data transforms."""
    norm_stats = {}
    if not skip_norm_stats and data_config.repo_id != "fake":
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats

    # Pipeline order: repack -> data transforms -> normalize -> model transforms.
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return IterableTransformedDataset(dataset, pipeline, is_batched=is_batched)
|
||||
|
||||
|
||||
def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Dispatches to the RLDS loader when the data config points at an RLDS
    directory, and to the torch/LeRobot loader otherwise.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    logging.info(f"data_config: {data_config}")

    if data_config.rlds_data_dir is not None:
        return create_rlds_data_loader(
            data_config,
            action_horizon=config.model.action_horizon,
            batch_size=config.batch_size,
            sharding=sharding,
            shuffle=shuffle,
            num_batches=num_batches,
            skip_norm_stats=skip_norm_stats,
            framework=framework,
        )
    return create_torch_data_loader(
        data_config,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
    )
|
||||
|
||||
def create_data_loader_multi(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader over a mixture of datasets for training.

    Unlike `create_data_loader`, `config.data` is iterated as a collection of
    data-config factories: one config (or group of configs) is created per
    factory and all of them are handed to the multi-dataset torch loader.

    Args:
        config: The training configuration; `config.data` must be iterable.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
        global_norm_stats: Optional pre-computed normalization stats that are
            passed to every data-config factory.
    """
    data_configs_list = []
    for data_config_factory in config.data:
        # Each factory may consume the shared global normalization stats.
        data_configs = data_config_factory.create(config.model, global_norm_stats)
        logging.info(f"data_config: {data_configs}")
        data_configs_list.append(data_configs)

    return create_torch_data_loader_multi(
        data_configs_list,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
        global_norm_stats=global_norm_stats,
    )
def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        model_config: The model configuration, used to build the torch dataset.
        action_horizon: The action horizon.
        batch_size: The *global* batch size; split across ranks/processes below.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data (also seeds the torch loader's RNG).
        framework: "jax" or "pytorch"; selects how the batch is split and which
            distributed sampler is used.
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)

    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            # Single-process PyTorch: the whole batch lives on this rank.
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            # Multi-host JAX: each process reads a disjoint slice of the dataset.
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    return DataLoaderImpl(data_config, data_loader)
def create_torch_data_loader_multi(
    data_configs_list: list[_config.DataConfig],
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader over a mixture of datasets for training.

    Args:
        data_configs_list: Data configurations, one entry per dataset in the mixture.
            NOTE(review): the body indexes `data_configs_list[0][0]`, which implies each
            element is itself a sequence of configs and conflicts with the
            `list[_config.DataConfig]` annotation -- confirm the element type.
        model_config: The model configuration, used to build the mixture dataset.
        action_horizon: The action horizon.
        batch_size: The *global* batch size; split across ranks/processes below.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
            NOTE(review): currently unused here -- presumably normalization happens
            inside `create_mixture_dataset`; confirm.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: "jax" or "pytorch"; selects how the batch is split and which
            distributed sampler is used.
        global_norm_stats: NOTE(review): accepted but unused in this function -- confirm intent.
    """
    dataset = create_mixture_dataset(data_configs_list, action_horizon, model_config)
    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            # Single-process PyTorch: the whole batch lives on this rank.
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            # Multi-host JAX: each process reads a disjoint slice of the dataset.
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    # NOTE(review): only the first config of the first mixture entry is exposed through
    # the returned loader's `data_config()` -- confirm this is the intended contract.
    return DataLoaderImpl(data_configs_list[0][0], data_loader)
def create_rlds_data_loader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create an RLDS data loader for training.

    Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md

    Args:
        data_config: The data configuration.
        action_horizon: The action horizon.
        batch_size: The batch size.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        framework: Only "jax" is supported for RLDS data.
    """
    if framework == "pytorch":
        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")

    # Build the (already batched) RLDS dataset and apply the configured transforms.
    raw_dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
    transformed_dataset = transform_iterable_dataset(
        raw_dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True
    )

    loader = RLDSDataLoader(transformed_dataset, sharding=sharding, num_batches=num_batches)
    return DataLoaderImpl(data_config, loader)
class JaxProcessDistributedSampler(torch.utils.data.Sampler[int]):
    """Simple sampler to split dataset indices across JAX processes.

    Each process sees a disjoint slice of indices using striding by num_replicas.
    The remainder (dataset_size % num_replicas) is dropped so that every process
    yields exactly the same number of indices -- unequal per-rank counts would
    stall data-parallel training. This mirrors DistributedSampler(drop_last=True)
    used for the PyTorch path. Shuffling (if enabled) is deterministic via the
    provided seed.
    """

    def __init__(
        self,
        dataset_size: int,
        *,
        num_replicas: int,
        rank: int,
        shuffle: bool,
        seed: int,
    ) -> None:
        # Clamp inputs defensively so degenerate values cannot produce negative
        # ranges or division by zero.
        self._dataset_size = max(0, dataset_size)
        self._num_replicas = max(1, num_replicas)
        self._rank = max(0, rank)
        self._shuffle = shuffle
        self._seed = seed

    def __iter__(self):
        indices = list(range(self._dataset_size))
        if self._shuffle and self._dataset_size > 0:
            # Deterministic permutation: every process computes the same order.
            g = torch.Generator()
            g.manual_seed(self._seed)
            indices = torch.randperm(self._dataset_size, generator=g).tolist()
        # Drop the remainder for balance, then take a strided slice per process.
        # BUG FIX: the previous implementation strided over the full index list,
        # so ranks received unequal numbers of indices when the dataset size was
        # not a multiple of num_replicas (contradicting the "drop remainder"
        # intent and desynchronizing multi-process training).
        usable = (self._dataset_size // self._num_replicas) * self._num_replicas
        indices = indices[:usable][self._rank :: self._num_replicas]
        return iter(indices)

    def __len__(self) -> int:
        # Equal length on every rank after dropping the remainder.
        return self._dataset_size // self._num_replicas
class TorchDataLoader:
    """Torch data loader implementation.

    Wraps `torch.utils.data.DataLoader` and re-creates its iterator whenever the
    dataset is exhausted, so iteration continues indefinitely (or until
    `num_batches` items have been produced). When a JAX sharding is in effect,
    batches are converted to globally-sharded `jax.Array`s; otherwise they are
    returned as torch tensors.
    """

    def __init__(
        self,
        dataset,
        local_batch_size: int,
        *,
        sharding: jax.sharding.Sharding | None = None,
        shuffle: bool = False,
        sampler: torch.utils.data.Sampler | None = None,
        num_batches: int | None = None,
        num_workers: int = 0,
        seed: int = 0,
        framework: str = "jax",
    ):
        """Create a PyTorch data loader.

        Args:
            dataset: The dataset to load.
            local_batch_size: The local batch size for each process.
            sharding: The sharding to use for the data loader.
            shuffle: Whether to shuffle the data.
            sampler: Optional sampler; when provided, shuffling is delegated to it.
            num_batches: If provided, determines the number of returned batches. If the
                number is larger than the number of batches in the dataset, the data loader
                will loop over the dataset. If not provided, will iterate over the dataset
                indefinitely.
            num_workers: The number of worker processes to use. If zero, the data loader will
                execute in the main process.
            seed: The seed to use for shuffling the data.
            framework: "jax" or "pytorch"; controls the default sharding behavior.
        """
        if len(dataset) < local_batch_size:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

        # Store sharding - None for PyTorch, JAX sharding for JAX
        self._sharding = sharding
        if sharding is None and framework == "jax":
            # Use data parallel sharding by default for JAX only.
            self._sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._num_batches = num_batches

        mp_context = None
        if num_workers > 0:
            # Spawn (rather than fork) so worker processes do not inherit GPU/JAX state.
            mp_context = multiprocessing.get_context("spawn")

        generator = torch.Generator()
        generator.manual_seed(seed)
        self._data_loader = torch.utils.data.DataLoader(
            typing.cast(torch.utils.data.Dataset, dataset),
            batch_size=local_batch_size,
            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
            sampler=sampler,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=num_workers > 0,
            collate_fn=_collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=True,
            generator=generator,
            pin_memory=False,
        )

    @property
    def torch_loader(self) -> torch.utils.data.DataLoader:
        """The underlying torch DataLoader."""
        return self._data_loader

    # BUG FIX: removed a stray `@profile` decorator here. `profile` is only
    # injected as a builtin when running under kernprof/line_profiler, so the
    # decorator raised NameError during normal execution.
    def __iter__(self):
        num_items = 0
        while True:
            data_iter = iter(self._data_loader)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
                if self._sharding is not None:
                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                else:
                    yield jax.tree.map(torch.as_tensor, batch)
def _collate_fn(items):
|
||||
"""Collate the batch elements into batched numpy arrays."""
|
||||
# Make sure to convert to numpy arrays before stacking since some of the incoming elements
|
||||
# may be JAX arrays.
|
||||
return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)
|
||||
|
||||
|
||||
def _worker_init_fn(worker_id: int) -> None:
|
||||
"""Tell JAX inside the worker process not to preallocate the GPU memory."""
|
||||
# NOTE: This is called after jax is imported inside the worker process. This
|
||||
# means that this approach will not work for selecting the backend.
|
||||
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
|
||||
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
||||
|
||||
|
||||
class RLDSDataLoader:
|
||||
"""Shallow wrapper around the DROID data loader to make it compatible with openpi.
|
||||
|
||||
All batching already happens in the DROID dataset, so we don't need to do anything here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: DroidRldsDataset,
|
||||
*,
|
||||
sharding: jax.sharding.Sharding | None = None,
|
||||
num_batches: int | None = None,
|
||||
):
|
||||
self._dataset = dataset
|
||||
self._num_batches = num_batches
|
||||
|
||||
if jax.process_count() > 1:
|
||||
raise NotImplementedError("Data loading with multiple processes is not supported.")
|
||||
|
||||
if sharding is None:
|
||||
# Use data parallel sharding by default.
|
||||
sharding = jax.sharding.NamedSharding(
|
||||
jax.sharding.Mesh(jax.devices(), ("B",)),
|
||||
jax.sharding.PartitionSpec("B"),
|
||||
)
|
||||
|
||||
self._sharding = sharding
|
||||
self._num_batches = num_batches
|
||||
|
||||
def __iter__(self):
|
||||
num_items = 0
|
||||
while True:
|
||||
data_iter = iter(self._dataset)
|
||||
while True:
|
||||
if self._num_batches is not None and num_items >= self._num_batches:
|
||||
return
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
except StopIteration:
|
||||
break # We've exhausted the dataset. Create a new iterator and start over.
|
||||
num_items += 1
|
||||
yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
|
||||
|
||||
|
||||
class DataLoaderImpl(DataLoader):
    """Adapts a low-level batch loader to the (Observation, Actions) interface."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
        # Keep the config around so consumers can query how the data was built.
        self._data_config = data_config
        self._data_loader = data_loader

    def data_config(self) -> _config.DataConfig:
        """Return the data configuration backing this loader."""
        return self._data_config

    def __iter__(self):
        for raw_batch in self._data_loader:
            observation = _model.Observation.from_dict(raw_batch)
            yield observation, raw_batch["actions"]
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
RLDS-based data loader for DROID.
|
||||
While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID.
|
||||
Thus, we provide a data loader example here that uses the RLDS data format.
|
||||
The data loader also applies a few DROID-specific data filters / transformations.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from enum import auto
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
|
||||
import openpi.shared.download as download
|
||||
|
||||
|
||||
class DroidActionSpace(Enum):
    """Action space for DROID dataset."""

    # Joint position targets (the default elsewhere in this file, since they
    # allow policy evaluation in simulation).
    JOINT_POSITION = auto()
    # Joint velocity commands.
    JOINT_VELOCITY = auto()
class DroidRldsDataset:
    """RLDS-backed DROID dataset that yields batches of action-chunked frames.

    Builds a tf.data pipeline (via dlimp) that filters unsuccessful episodes,
    restructures observations/actions, chunks actions, optionally filters frames
    via a JSON "filter dictionary", decodes images, shuffles, and batches.
    Iteration yields numpy batches.
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        *,  # Force keyword-only arguments
        shuffle: bool = True,
        action_chunk_size: int = 16,
        # We default to joint position actions, since they allow policy evaluation in simulation.
        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
        # NOTE(review): max_loaded_steps_per_episode is currently unused in this pipeline -- confirm.
        max_loaded_steps_per_episode: int = 100,
        # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        filter_dict_path=None,  # Path to json file with indices to sample during training
    ):
        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
        import dlimp as dl
        import tensorflow as tf
        import tensorflow_datasets as tfds

        # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
        tf.config.set_visible_devices([], "GPU")

        builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1")
        dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)

        # Filter out any unsuccessful trajectories -- we use the file name to check this
        dataset = dataset.filter(
            lambda traj: tf.strings.regex_full_match(
                traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
            )
        )

        # Repeat dataset so we never run out of data.
        dataset = dataset.repeat()

        # Load the filter dictionary if provided.
        # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample
        # (e.g.,
        # {
        #     "<episode key>": [[0, 100], [200, 300]]
        # }
        # means keep frames 0-99 and 200-299).
        if filter_dict_path is not None:
            cached_filter_dict_path = download.maybe_download(filter_dict_path)
            with Path(cached_filter_dict_path).open("r") as f:
                filter_dict = json.load(f)

            logging.info(f"Using filter dictionary with {len(filter_dict)} episodes")

            keys_tensor = []
            values_tensor = []

            # Expand each [start, end) range into one hash-table key per frame.
            for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."):
                for start, end in ranges:
                    for t in range(start, end):
                        frame_key = f"{episode_key}--{t}"
                        keys_tensor.append(frame_key)
                        values_tensor.append(True)
            # Frames not present in the table are dropped (default_value=False).
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False
            )
            logging.info("Filter hash table initialized")
        else:
            # No filter dict: a degenerate table that keeps every frame (default_value=True).
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True
            )

        def restructure(traj):
            """Reformat observation and action keys, sample language instruction."""
            # Important: we use joint *position* action space -- easier to simulate!
            actions = tf.concat(
                (
                    (
                        traj["action_dict"]["joint_position"]
                        if action_space == DroidActionSpace.JOINT_POSITION
                        else traj["action_dict"]["joint_velocity"]
                    ),
                    traj["action_dict"]["gripper_position"],
                ),
                axis=-1,
            )
            # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
            # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera.
            exterior_img = tf.cond(
                tf.random.uniform(shape=[]) > 0.5,
                lambda: traj["observation"]["exterior_image_1_left"],
                lambda: traj["observation"]["exterior_image_2_left"],
            )
            wrist_img = traj["observation"]["wrist_image_left"]
            # Randomly sample one of the three language instructions
            instruction = tf.random.shuffle(
                [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
            )[0]

            traj_len = tf.shape(traj["action"])[0]
            indices = tf.as_string(tf.range(traj_len))

            # Data filtering:
            # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,
            # and each step's time step index. This will index into the filter hash table, and if it returns true,
            # then the frame passes the filter.
            step_id = (
                traj["traj_metadata"]["episode_metadata"]["recording_folderpath"]
                + "--"
                + traj["traj_metadata"]["episode_metadata"]["file_path"]
                + "--"
                + indices
            )
            passes_filter = self.filter_table.lookup(step_id)

            return {
                "actions": actions,
                "observation": {
                    "image": exterior_img,
                    "wrist_image": wrist_img,
                    "joint_position": traj["observation"]["joint_position"],
                    "gripper_position": traj["observation"]["gripper_position"],
                },
                "prompt": instruction,
                "step_id": step_id,
                "passes_filter": passes_filter,
            }

        dataset = dataset.traj_map(restructure, num_parallel_calls)

        def chunk_actions(traj):
            """Splits episode into action chunks."""
            traj_len = tf.shape(traj["actions"])[0]

            # For each step in the trajectory, construct indices for the next n actions
            action_chunk_indices = tf.broadcast_to(
                tf.range(action_chunk_size)[None],
                [traj_len, action_chunk_size],
            ) + tf.broadcast_to(
                tf.range(traj_len)[:, None],
                [traj_len, action_chunk_size],
            )

            # Cap to length of the sequence --> final chunks will repeat the last action
            # This makes sense, since we are using absolute joint + gripper position actions
            action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)

            # Gather the actions for each chunk
            traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
            return traj

        dataset = dataset.traj_map(chunk_actions, num_parallel_calls)

        # Flatten: map from trajectory dataset to dataset of individual action chunks
        dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)

        # Filter data that doesn't pass the filter
        def filter_from_dict(frame):
            return frame["passes_filter"]

        dataset = dataset.filter(filter_from_dict)

        # Remove "passes_filter" key from output
        def remove_passes_filter(frame):
            frame.pop("passes_filter")
            return frame

        dataset = dataset.map(remove_passes_filter)

        # Decode images: RLDS saves encoded images, only decode now for efficiency
        def decode_images(traj):
            traj["observation"]["image"] = tf.io.decode_image(
                traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
            )
            traj["observation"]["wrist_image"] = tf.io.decode_image(
                traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
            )
            return traj

        dataset = dataset.frame_map(decode_images, num_parallel_calls)

        # Shuffle, batch
        dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.batch(batch_size)
        # Note =>> Seems to reduce memory usage without affecting speed?
        dataset = dataset.with_ram_budget(1)

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        """Yield batches as numpy structures."""
        yield from self.dataset.as_numpy_iterator()

    def __len__(self):
        # This is the approximate number of samples in DROID after filtering.
        # Easier to hardcode than to iterate through the dataset and compute it.
        return 20_000_000
@@ -0,0 +1,116 @@
|
||||
"""RoboArena baseline policy configs."""
|
||||
|
||||
from typing import TypeAlias
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
import openpi.models.pi0_fast as pi0_fast
|
||||
import openpi.models.tokenizer as _tokenizer
|
||||
import openpi.policies.droid_policy as droid_policy
|
||||
import openpi.transforms as _transforms
|
||||
|
||||
ModelType: TypeAlias = _model.ModelType
|
||||
|
||||
|
||||
def get_roboarena_configs():
    """Return the RoboArena DROID baseline `TrainConfig`s.

    Each entry pairs a model variant (different action tokenizers, or the
    pi0-style flow model) with the same DROID data configuration.
    """
    # Import here to avoid circular imports.
    from openpi.training.config import AssetsConfig
    from openpi.training.config import DataConfig
    from openpi.training.config import SimpleDataConfig
    from openpi.training.config import TrainConfig

    return [
        #
        # RoboArena DROID baseline inference configs.
        #
        TrainConfig(
            # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
            name="paligemma_binning_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                max_token_len=400,
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
            name="paligemma_fast_droid",
            model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
            name="paligemma_fast_specialist_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FASTTokenizer,
                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FSQ tokenizer.
            name="paligemma_vq_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FSQTokenizer,
                fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
            name="paligemma_diffusion_droid",
            model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    # No model_type here: DroidInputs uses its default for the pi0 flow model.
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
    ]
@@ -0,0 +1,703 @@
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import SupportsIndex, Sequence, List, Dict, Any, Tuple, Optional, Union, TypeVar, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
|
||||
import openpi.transforms as _transforms
|
||||
from pdb import set_trace
|
||||
import logging
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
import openpi.training.config as _config
|
||||
import openpi.shared.normalize as normalize
|
||||
|
||||
def detect_gripper_change_step(
    dataset,
    select_actions: Optional[list[str]] = None,
    gripper_dim: int = -1,
    threshold_method: str = "std_multiplier",
    threshold_multiplier: float = 2.0,
    min_threshold: float = 0.001,
    max_threshold: float = 1.0,
    plot_gripper_changes: bool = False,
):
    """
    Detect the steps at which the gripper value changes significantly and attach
    them to `dataset` as `dataset.gripper_change_step_idx` (sorted int32 array).
    Only works for the self-collected dataset. Each detected change contributes a
    sliding window of 4 step indices [idx-2, idx+1], clipped to the valid range
    and de-duplicated across all action keys.

    Args:
        dataset: LeRobotDataset instance (modified in place).
        select_actions: Action keys to process; defaults to ["action"] when None.
            (A None sentinel replaces the previous mutable list default, which
            was shared across calls.)
        gripper_dim: Dimension index of the gripper within the action vector.
        threshold_method: 'std_multiplier', 'percentile', or 'absolute'.
        threshold_multiplier: Multiplier for the std-based threshold; also used
            directly as the threshold for 'absolute'.
        min_threshold: Lower clamp on the threshold (avoid over-sensitivity).
        max_threshold: Upper clamp on the threshold (avoid missing big changes).
        plot_gripper_changes: Whether to plot a visualization of the detections.

    Returns:
        The same dataset object, with `gripper_change_step_idx` attached.

    Raises:
        ValueError: If `threshold_method` is not one of the supported values.
    """
    if select_actions is None:
        select_actions = ["action"]

    episode_lengths = [ep_dict["length"] for ep_dict in dataset.meta.episodes.values()]
    cumulative_lengths = np.cumsum(episode_lengths)

    all_window_indices = set()  # Set gives automatic deduplication across keys.

    for action_key in select_actions:
        action_values = dataset.hf_dataset[action_key]

        delta_action = np.diff(action_values, axis=0)

        # The diff taken across an episode boundary is meaningless; replace it
        # with the previous in-episode diff (or zero if there is none).
        for end_idx in cumulative_lengths[:-1]:
            if end_idx - 1 < len(delta_action) and end_idx - 2 >= 0:
                delta_action[end_idx - 1] = delta_action[end_idx - 2]
            elif end_idx - 1 < len(delta_action):
                delta_action[end_idx - 1] = 0

        if delta_action.ndim == 1:
            delta_action = delta_action[:, np.newaxis]

        assert delta_action.ndim == 2

        # Per-step change of the gripper dimension only.
        gripper_delta = delta_action[:, gripper_dim]

        # Calculate the detection threshold from statistical properties.
        if threshold_method == "std_multiplier":
            # Standard deviation filters out small tremors.
            std_val = np.std(gripper_delta)
            threshold = threshold_multiplier * std_val
        elif threshold_method == "percentile":
            # 85th percentile of |delta| as the threshold.
            threshold = np.percentile(np.abs(gripper_delta), 85)
        elif threshold_method == "absolute":
            # Use the multiplier directly as an absolute threshold.
            threshold = threshold_multiplier
        else:
            raise ValueError(f"Unknown threshold_method: {threshold_method}")

        # Clamp the threshold to reasonable bounds.
        threshold = np.clip(threshold, min_threshold, max_threshold)

        # Steps whose gripper change exceeds the threshold.
        significant_change_idx = np.where(np.abs(gripper_delta) > threshold)[0]

        cur_window_indices = set()
        for idx in significant_change_idx:
            # Window of size 4 centered around idx: [idx-2, idx-1, idx, idx+1].
            window_start = idx - 2
            window_end = idx + 1

            # delta_action[i] is the change from step i to i+1, so valid step
            # indices run up to len(action_values) - 1; clip the window to that.
            max_possible_idx = len(action_values) - 1

            current_window_indices = np.arange(
                max(0, window_start), min(max_possible_idx + 1, window_end + 1)
            )
            for w_idx in current_window_indices:
                cur_window_indices.add(w_idx)
                all_window_indices.add(w_idx)

        if plot_gripper_changes:
            num_episodes_to_plot = 5
            end_index_for_plot = cumulative_lengths[num_episodes_to_plot - 1] - 1
            delta_action_to_plot = delta_action[:end_index_for_plot]

            # Only plot change steps that fall inside the plotted range.
            gripper_change_step_idx = np.array(sorted(list(cur_window_indices))).astype(np.int32)
            gripper_change_step_idx_to_plot = gripper_change_step_idx[gripper_change_step_idx < end_index_for_plot]

            # NOTE(review): `plot_gripper_changes_in_subplots` is not defined or
            # imported in this module — confirm it exists before enabling plots.
            plot_gripper_changes_in_subplots(
                delta_action_to_plot,
                gripper_change_step_idx_to_plot,
                episode_lengths,
                num_episodes_to_plot,
                gripper_dim,
                f"{action_key}_gripper_change"
            )

    # Convert the accumulated set to a sorted numpy array.
    gripper_change_step_idx = np.array(sorted(list(all_window_indices))).astype(np.int32)

    # NOTE: `action_values` here is the array from the LAST processed key; this
    # also means select_actions must be non-empty, as in the original code.
    print(f"Total unique gripper change steps: {len(gripper_change_step_idx)}, Total steps: {len(action_values)}")

    dataset.gripper_change_step_idx = gripper_change_step_idx

    return dataset
|
||||
|
||||
class Dataset(Protocol[T_co]):
    """Structural interface for a random-access dataset.

    Anything that can report its own length and be indexed by an integer
    satisfies this protocol; concrete implementations override both methods.
    """

    def __len__(self) -> int:
        """Return the number of samples available."""
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Return the sample stored at ``index``."""
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
|
||||
|
||||
class TransformedDataset(Dataset[T_co]):
    """Wraps a dataset so every fetched sample runs through a fixed transform chain."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        """Compose ``transforms`` once up front and remember the wrapped dataset."""
        self._transform = _transforms.compose(transforms)
        self._dataset = dataset

    def __len__(self) -> int:
        """Length is that of the wrapped dataset, unchanged."""
        return len(self._dataset)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Fetch the raw sample, then push it through the composed transform."""
        raw_sample = self._dataset[index]
        return self._transform(raw_sample)
|
||||
|
||||
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig) -> Dataset:
    """Wrap `dataset` so each sample passes through the configured input pipeline.

    Transform order: repack transforms -> data transforms -> normalization ->
    model transforms.

    Args:
        dataset: Source dataset providing raw samples.
        data_config: Config supplying the transform groups, normalization stats,
            and the quantile-normalization flag.

    Returns:
        A `TransformedDataset` that applies the full pipeline on access.
    """
    # The original code assigned `norm_stats = {}` and immediately overwrote it;
    # the dead assignment is removed here.
    # NOTE(review): `data_config.norm_stats` may be None — presumably Normalize
    # accepts that; confirm against _transforms.Normalize.
    norm_stats = data_config.norm_stats

    return TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
    )
|
||||
|
||||
class MixtureDataset(Dataset):
    """
    A composite dataset that combines multiple datasets, allowing for weighted
    sampling and gripper-change-based augmentation weighting.

    This dataset flattens all eligible samples from its constituent datasets and
    assigns sampling weights based on configuration and heuristics
    (e.g., `gripper_aug_ratio`).
    """

    def __init__(
        self,
        datasets: Sequence[Dataset],
        datasets_name: Sequence[str],
        datasets_meta: Sequence[LeRobotDatasetMetadata],
        datasets_weights: Optional[Dict[str, float]] = None,
        gripper_aug_ratio: float = 1.0,
        shuffle: bool = True,
    ):
        """
        Initializes the MixtureDataset.

        Args:
            datasets (Sequence[Dataset]): The `Dataset` objects to be combined
                (typically `TransformedDataset` wrappers around LeRobot datasets).
            datasets_name (Sequence[str]): A name for each dataset in `datasets`.
            datasets_meta (Sequence[LeRobotDatasetMetadata]): Metadata for each
                dataset; `info['total_episodes']` and `info['total_frames']` are read.
            datasets_weights (Optional[Dict[str, float]]): Maps dataset names to
                base sampling weights. If None, equal weights (1/N) are used.
            gripper_aug_ratio (float): Multiplier applied to the weight of samples
                whose step index is flagged as a gripper change (see
                `detect_gripper_change_step`). Useful for augmenting rare events.
            shuffle (bool): If True, the flat sample map and sampling weights are
                shuffled (together, in the same order) after construction.
        """
        self.datasets = datasets
        self.datasets_name = datasets_name
        self.meta = datasets_meta
        # Extract total number of episodes and frames for each dataset from metadata.
        self.num_episodes = [meta.info['total_episodes'] for meta in datasets_meta]
        self.num_frames = [meta.info['total_frames'] for meta in datasets_meta]


        # Compute the flattened list of (dataset_idx, sample_idx) pairs.
        # The False argument is the `is_eval` flag of `_compute_len`.
        self._compute_len(False)
        # Assign normalized sampling weights to each sample in the flattened map.
        self._get_weights(datasets_weights, gripper_aug_ratio)

        # Sanity check: one weight per flattened sample.
        if len(self.flat_sample_map) != len(self.sample_weights):
            raise ValueError(
                f"Mismatch in flat sample map length ({len(self.flat_sample_map)}) "
                f"and sample weights length ({len(self.sample_weights)})."
            )
        if shuffle:
            # Shuffle both the sample map and weights in the same order so each
            # sample keeps its assigned probability.
            indices = np.random.permutation(len(self.flat_sample_map))
            self.flat_sample_map = [self.flat_sample_map[i] for i in indices]
            self.sample_weights = self.sample_weights[indices]

    def __len__(self) -> int:
        """
        Returns the total number of samples in the mixture dataset (after
        flattening and selection).
        """
        return len(self.flat_sample_map)

    def __getitem__(self, index: SupportsIndex):
        """
        Retrieves one sample from the underlying datasets via the flattened map.

        Args:
            index (SupportsIndex): Index into `flat_sample_map` (0 to `len(self) - 1`).

        Returns:
            The sample returned by the underlying dataset's `__getitem__`
            (NOT a (dataset_idx, sample) tuple — the original docstring was wrong).

        Raises:
            IndexError: If the provided index is out of bounds.
        """
        if not (0 <= index < len(self.flat_sample_map)):
            raise IndexError(f"Index {index} is out of bounds for the dataset (size: {len(self.flat_sample_map)}).")

        # Map the flat index back to (which dataset, which sample in it).
        dataset_idx, sample_idx = self.flat_sample_map[index]
        return self.datasets[dataset_idx][sample_idx]

    def _compute_len(self, is_eval: bool = False):
        """
        Pre-computes `all_sample_indices`: per-dataset lists of frame-index
        tensors sampled per episode, then flattens them into `flat_sample_map`.

        Args:
            is_eval (bool): Forwarded to `_sample_indices`.
        """
        self.all_sample_indices: List[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = []

        for i, (ds, meta) in enumerate(zip(self.datasets, self.meta)):
            # Unwrap TransformedDataset to reach the LeRobot dataset underneath.
            actual_ds = ds._dataset if isinstance(ds, TransformedDataset) else ds

            # None means "use every frame" in `_sample_indices` (its num_frames
            # parameter is currently unused there).
            num_indices = None

            if isinstance(actual_ds, MultiLeRobotDataset):
                # For MultiLeRobotDataset, collect indices per sub-dataset.
                indices_list_for_multi_ds = []
                for sub_ds in actual_ds._datasets:
                    _from = sub_ds.episode_data_index["from"]
                    _to = sub_ds.episode_data_index["to"]
                    indices = self._sample_indices(
                        _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
                    )
                    indices_list_for_multi_ds.append(indices)
                self.all_sample_indices.append(indices_list_for_multi_ds)
            elif isinstance(actual_ds, LeRobotDataset):
                # For a single LeRobotDataset.
                _from = actual_ds.episode_data_index["from"]
                _to = actual_ds.episode_data_index["to"]
                indices = self._sample_indices(
                    _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
                )
                self.all_sample_indices.append(indices)
            else:
                raise TypeError(f"Unsupported dataset type: {type(actual_ds)}. "
                                "Expected `LeRobotDataset` or `MultiLeRobotDataset`.")

        # Flatten all sampled episode indices into (dataset_idx, sample_idx) pairs.
        self.flat_sample_map = self._create_flat_sample_map()

    def _create_flat_sample_map(self) -> List[Tuple[int, int]]:
        """
        Flattens `self.all_sample_indices` (lists of lists of tensors, lists of
        tensors, or a bare tensor) into a list of
        `(original_dataset_index, sample_index_within_original_dataset)` tuples,
        used by `__getitem__`.
        """
        flat_map = []
        for dataset_idx, sample_group in enumerate(self.all_sample_indices):
            # Case 1: `MultiLeRobotDataset` — `sample_group` is List[List[torch.Tensor]].
            if isinstance(sample_group, list) and len(sample_group) > 0 and isinstance(sample_group[0], list):
                for sub_group in sample_group:  # per sub-dataset
                    for tensor_of_indices in sub_group:  # per episode
                        for i in range(tensor_of_indices.numel()):
                            flat_map.append((dataset_idx, tensor_of_indices[i].item()))
            # Case 2: `LeRobotDataset` — `sample_group` is List[torch.Tensor].
            elif isinstance(sample_group, list) and len(sample_group) > 0 and isinstance(sample_group[0], torch.Tensor):
                for tensor_of_indices in sample_group:
                    for i in range(tensor_of_indices.numel()):
                        flat_map.append((dataset_idx, tensor_of_indices[i].item()))
            # Case 3: rare — `sample_group` is a single torch.Tensor.
            elif isinstance(sample_group, torch.Tensor):
                for i in range(sample_group.numel()):
                    flat_map.append((dataset_idx, sample_group[i].item()))
        return flat_map

    def _sample_indices(
        self,
        start: List[int],
        end: List[int],
        num_frames: Optional[int],
        random_pad: bool = False,
        is_eval: bool = False,
        dataset_name: Optional[str] = None,  # for debugging / future per-dataset rules
    ) -> List[torch.Tensor]:
        """
        Samples frame indices per episode.

        NOTE(review): since `target_frames` is set to `frame_count`, the first
        branch always fires and every frame of every episode is kept; the
        padding branch is currently dead code, and `num_frames`/`is_eval` are
        unused. Kept as-is pending clarification.

        Args:
            start (List[int]): Starting frame index of each episode.
            end (List[int]): Ending frame index (exclusive) of each episode.
            num_frames (Optional[int]): Target frames per episode (currently unused).
            random_pad (bool): If True, short episodes would be padded with random
                duplicates of their own indices (currently unreachable).
            is_eval (bool): Evaluation flag (currently unused).
            dataset_name (Optional[str]): Dataset name, for debugging.

        Returns:
            List[torch.Tensor]: One tensor of sampled frame indices per episode.
        """
        all_indices_for_episodes = []
        for _start, _end in zip(start, end):
            frame_count = _end - _start  # Total frames available in this episode.
            target_frames = frame_count
            if frame_count >= target_frames:
                # Linearly space `target_frames` indices over the episode
                # (with target_frames == frame_count this keeps every frame).
                indices = torch.linspace(_start, _end - 1, steps=target_frames).long()
            else:
                # Fewer frames than `target_frames` (unreachable as written).
                if random_pad:
                    # Pad with randomly chosen duplicates from the episode.
                    pad_size = target_frames - frame_count
                    indices = torch.arange(_start, _end)  # All available original indices
                    pad_indices = indices[torch.randint(0, frame_count, (pad_size,))]
                    indices = torch.cat([indices, pad_indices])  # original + padded
                    indices = indices[torch.randperm(target_frames)]  # mix them
                else:
                    # Otherwise simply use all available frames.
                    indices = torch.arange(_start, _end)

            all_indices_for_episodes.append(indices)

        return all_indices_for_episodes

    def _get_weights(self, datasets_weights: Optional[Dict[str, float]], aug_ratio: float = 1.0):
        """
        Assigns a normalized sampling weight to each sample in the flattened map.
        A sample whose step index appears in its dataset's
        `gripper_change_step_idx` gets `base_weight * aug_ratio`; otherwise
        `base_weight`. Also fills `datasets_weight_map` with each dataset's
        effective share of the total.

        Args:
            datasets_weights (Optional[Dict[str, float]]): Base weight per dataset
                name; missing names default to 1.0. None means equal weights.
            aug_ratio (float): Multiplier for gripper-change samples.
        """
        self.sample_weights: List[float] = []
        self.datasets_weight_map: Dict[str, float] = {}

        if datasets_weights is None:
            num_datasets = len(self.datasets_name)
            datasets_weights = {name: 1.0 / num_datasets for name in self.datasets_name}

        for idx, ds_name in enumerate(self.datasets_name):
            # Unwrap TransformedDataset to reach gripper-change info on the base dataset.
            current_base_dataset = self.datasets[idx]._dataset if isinstance(self.datasets[idx], TransformedDataset) else self.datasets[idx]
            base_weight = datasets_weights.get(ds_name, 1.0)  # Get base weight for this dataset

            individual_weights_for_ds: List[float] = []

            # Retrieve `gripper_change_step_idx` (set by detect_gripper_change_step)
            # and assign per-sample weights.
            if isinstance(current_base_dataset, MultiLeRobotDataset):
                # For MultiLeRobotDataset, iterate through its sub-datasets.
                for idj, sub_ds in enumerate(current_base_dataset._datasets):
                    gripper_change_step_idx = getattr(sub_ds, 'gripper_change_step_idx', None)
                    if gripper_change_step_idx is not None:
                        sampled_indices_sub_ds = self.all_sample_indices[idx][idj]
                        for tensor_of_indices in sampled_indices_sub_ds:
                            for step_idx in tensor_of_indices.tolist():
                                # NOTE(review): membership test against a numpy
                                # array is O(n) per step — consider a set.
                                if step_idx in gripper_change_step_idx:
                                    individual_weights_for_ds.append(base_weight * aug_ratio)
                                else:
                                    individual_weights_for_ds.append(base_weight)
            elif isinstance(current_base_dataset, LeRobotDataset):
                # For a single LeRobotDataset.
                gripper_change_step_idx = getattr(current_base_dataset, 'gripper_change_step_idx', None)
                if gripper_change_step_idx is not None:
                    sampled_indices_ds = self.all_sample_indices[idx]
                    for tensor_of_indices in sampled_indices_ds:
                        for step_idx in tensor_of_indices.tolist():
                            if step_idx in gripper_change_step_idx:
                                individual_weights_for_ds.append(base_weight * aug_ratio)
                            else:
                                individual_weights_for_ds.append(base_weight)
            # Fallback: no gripper-change info — uniform `base_weight` per sample.
            if gripper_change_step_idx is None:
                print(f"Warning: Gripper change detection not fully supported for dataset type {type(current_base_dataset)}. "
                      "Assigning uniform weights based on `base_weight` for this dataset.")
                num_samples_for_ds_in_flat_map = sum(1 for map_ds_idx, _ in self.flat_sample_map if map_ds_idx == idx)
                individual_weights_for_ds.extend([base_weight] * num_samples_for_ds_in_flat_map)

            # Accumulate individual weights for all samples and for the dataset's total.
            self.sample_weights.extend(individual_weights_for_ds)
            self.datasets_weight_map[ds_name] = self.datasets_weight_map.get(ds_name, 0.0) + sum(individual_weights_for_ds)

        # Final normalization of all individual sample weights across the mixture.
        total_sum_of_all_individual_weights = sum(self.sample_weights)
        if total_sum_of_all_individual_weights > 0:
            self.sample_weights = np.array(self.sample_weights, dtype=np.float32)
            self.sample_weights = self.sample_weights / total_sum_of_all_individual_weights
        else:
            self.sample_weights = np.array([], dtype=np.float32)

        # Normalize `datasets_weight_map` to each dataset's effective proportion.
        if total_sum_of_all_individual_weights > 0:
            for k in self.datasets_weight_map:
                self.datasets_weight_map[k] /= total_sum_of_all_individual_weights
        else:
            self.datasets_weight_map = {k: 0.0 for k in self.datasets_weight_map}  # All weights become zero.

    def __str__(self) -> str:
        """
        Returns a formatted, ANSI-colored summary of the mixture: per-dataset
        sample counts and effective weights, plus the total episode count.
        """
        # ANSI escape codes for colored and bold text.
        RESET = "\033[0m"
        BOLD = "\033[1m"
        CYAN = "\033[96m"
        YELLOW = "\033[93m"
        GREEN = "\033[92m"  # currently unused
        MAGENTA = "\033[95m"

        # Maximum key length for aligned formatting (fallback width: 20).
        max_key_len = max(len(k) for k in self.datasets_weight_map.keys()) + 2 if self.datasets_weight_map else 20

        # Header line.
        lines = [
            f"{BOLD}{MAGENTA}######################################### 👈 Dataset Weight Map: ########################################{RESET}"
        ]

        # Per-dataset line: name, sample count, effective weight percentage.
        for idx, (name, weight) in enumerate(self.datasets_weight_map.items()):
            lines.append(f"{CYAN}{name:<{max_key_len}} : {len(self.datasets[idx]):>18.0f} ({weight*100:>.2f}%){RESET}")

        # Separator sized to the visible width of the header.
        separator_length = len(lines[0]) - len(BOLD) - len(MAGENTA) - len(RESET) + 1
        lines.append("-" * separator_length)

        # Total episodes summary.
        lines.append(f"{CYAN}{'Total Episodes':<{max_key_len}}{RESET} : {YELLOW}{sum(self.num_episodes):>18.0f}{RESET}")

        # Closing border matching the separator length.
        lines.append(f"{BOLD}{MAGENTA}{'#' * separator_length}{RESET}")

        return "\n".join(lines)
|
||||
|
||||
|
||||
def create_mixture_dataset(
    data_configs_list,
    action_horizon,
    model_config,
):
    """Build a MixtureDataset from nested dataset configs, applying transforms.

    For each config: loads metadata, optionally subsamples episodes at random
    (`data_ratio`), builds a LeRobotDataset with downsampled action timestamps,
    optionally runs gripper-change detection, and wraps it with
    `transform_dataset`. `model_config` is currently unused.
    """
    datasets = []
    names = []
    metas = []
    weights = {}

    for config_group in data_configs_list:
        for ds_config in config_group:
            root_path = f"{ds_config.repo_dir}/{ds_config.task_id}/{ds_config.subtask_id}"

            meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(meta.episodes_stats.keys())
            if ds_config.data_ratio < 1.0:
                # Randomly keep a `data_ratio` fraction of the episodes.
                sub_length = int(len(episodes) * ds_config.data_ratio) + 1
                logging.info(f"sub_length: {sub_length}")
                picked = np.random.choice(len(episodes), sub_length, replace=False)
                episodes = [episodes[i] for i in picked]
            print(f"downsample ratio: {ds_config.downsample_ratio}")

            effective_fps = meta.fps // ds_config.downsample_ratio
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                delta_timestamps={
                    key: [t / effective_fps for t in range(action_horizon)]
                    for key in ds_config.action_sequence_keys
                },
            )
            if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
                aug_cfg = ds_config.gripper_aug_config
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=aug_cfg["gripper_action_keys"],
                    gripper_dim=aug_cfg["gripper_dim"],
                    threshold_method=aug_cfg["gripper_threshold_method"],
                    threshold_multiplier=aug_cfg["gripper_threshold_multiplier"],
                    min_threshold=aug_cfg["gripper_min_threshold"],
                    max_threshold=aug_cfg["gripper_max_threshold"],
                )

            dataset = transform_dataset(dataset, ds_config)

            datasets.append(dataset)
            names.append(root_path)
            metas.append(meta)
            weights[root_path] = ds_config.weight

    return MixtureDataset(
        datasets,
        names,
        metas,
        weights,
        gripper_aug_ratio=10.0,
    )
|
||||
|
||||
def create_mixture_dataset_no_transform(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build a MixtureDataset from nested dataset configs WITHOUT transforms.

    Differs from `create_mixture_dataset`: episodes are truncated
    deterministically (first fraction) instead of sampled at random, and the
    raw LeRobotDataset objects are used directly. `model_config` is unused.
    """
    datasets = []
    names = []
    metas = []
    weights = {}

    for config_group in data_configs_list:
        for ds_config in config_group:
            root_path = f"{ds_config.repo_dir}/{ds_config.task_id}/{ds_config.subtask_id}"

            meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(meta.episodes_stats.keys())
            if ds_config.data_ratio < 1.0:
                # Deterministic truncation: keep the leading fraction.
                keep = int(len(episodes) * ds_config.data_ratio) + 1
                episodes = episodes[:keep]

            effective_fps = meta.fps // ds_config.downsample_ratio
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                delta_timestamps={
                    key: [t / effective_fps for t in range(action_horizon)]
                    for key in ds_config.action_sequence_keys
                },
            )
            if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
                aug_cfg = ds_config.gripper_aug_config
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=aug_cfg["gripper_action_keys"],
                    gripper_dim=aug_cfg["gripper_dim"],
                    threshold_method=aug_cfg["gripper_threshold_method"],
                    threshold_multiplier=aug_cfg["gripper_threshold_multiplier"],
                    min_threshold=aug_cfg["gripper_min_threshold"],
                    max_threshold=aug_cfg["gripper_max_threshold"],
                )

            datasets.append(dataset)
            names.append(root_path)
            metas.append(meta)
            weights[root_path] = ds_config.weight

    return MixtureDataset(
        datasets,
        names,
        metas,
        weights,
        gripper_aug_ratio=10.0,
    )
|
||||
|
||||
def create_mixture_dataset_calculate_norm_stats(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build an untransformed MixtureDataset for computing normalization stats.

    Takes a FLAT list of dataset configs (not nested groups, unlike the other
    creators) and constructs each LeRobotDataset with `load_video=False`, since
    pixel data is not needed for stats. `model_config` is currently unused.
    """
    datasets = []
    names = []
    metas = []
    weights = {}

    for ds_config in data_configs_list:
        root_path = f"{ds_config.repo_dir}/{ds_config.task_id}/{ds_config.subtask_id}"

        meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
        episodes = list(meta.episodes_stats.keys())
        if ds_config.data_ratio < 1.0:
            # Deterministic truncation: keep the leading fraction.
            keep = int(len(episodes) * ds_config.data_ratio) + 1
            episodes = episodes[:keep]

        effective_fps = meta.fps // ds_config.downsample_ratio
        dataset = LeRobotDataset(
            episodes=episodes,
            repo_id=root_path,
            root=root_path,
            delta_timestamps={
                key: [t / effective_fps for t in range(action_horizon)]
                for key in ds_config.action_sequence_keys
            },
            # Skip video decoding; only state/action streams matter for stats.
            load_video=False,
        )
        if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
            aug_cfg = ds_config.gripper_aug_config
            dataset = detect_gripper_change_step(
                dataset,
                select_actions=aug_cfg["gripper_action_keys"],
                gripper_dim=aug_cfg["gripper_dim"],
                threshold_method=aug_cfg["gripper_threshold_method"],
                threshold_multiplier=aug_cfg["gripper_threshold_multiplier"],
                min_threshold=aug_cfg["gripper_min_threshold"],
                max_threshold=aug_cfg["gripper_max_threshold"],
            )

        datasets.append(dataset)
        names.append(root_path)
        metas.append(meta)
        weights[root_path] = ds_config.weight

    return MixtureDataset(
        datasets,
        names,
        metas,
        weights,
        gripper_aug_ratio=10.0,
    )
|
||||
|
||||
123
policy/openpi-InternData-A1/src/openpi/training/optimizer.py
Normal file
123
policy/openpi-InternData-A1/src/openpi/training/optimizer.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import dataclasses
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
|
||||
@runtime_checkable
class LRScheduleConfig(Protocol):
    """Structural interface for learning-rate schedule configs.

    Implementations build and return a ready-to-use optax schedule.
    """

    def create(self) -> optax.Schedule: ...
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CosineDecaySchedule(LRScheduleConfig):
    """Linear warmup followed by cosine decay of the learning rate."""

    warmup_steps: int = 1_000
    peak_lr: float = 2.5e-5
    decay_steps: int = 30_000
    decay_lr: float = 2.5e-6

    def create(self) -> optax.Schedule:
        """Build the optax warmup+cosine schedule described by this config."""
        # Start the warmup slightly above zero so step 0 has a non-degenerate LR.
        warmup_start = self.peak_lr / (self.warmup_steps + 1)
        return optax.warmup_cosine_decay_schedule(
            init_value=warmup_start,
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            end_value=self.decay_lr,
        )
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class RsqrtDecaySchedule(LRScheduleConfig):
    """Linear warmup followed by inverse-square-root decay."""

    warmup_steps: int = 1_000
    peak_lr: float = 5e-5
    timescale: float = 10_000

    def create(self) -> optax.Schedule:
        """Stitch the warmup and rsqrt phases together with `join_schedules`."""

        def rsqrt_phase(step):
            # Decays as 1/sqrt(t) once `step` dominates the timescale.
            return self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale)

        warmup_phase = optax.linear_schedule(
            init_value=self.peak_lr / (self.warmup_steps + 1),
            end_value=self.peak_lr,
            transition_steps=self.warmup_steps,
        )
        return optax.join_schedules([warmup_phase, rsqrt_phase], [self.warmup_steps])
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class WarmupConstantSchedule(LRScheduleConfig):
    """Linear warmup into a constant learning rate."""

    warmup_steps: int = 2_000
    peak_lr: float = 5e-5

    def create(self) -> optax.Schedule:
        """Build the optax warmup-then-constant schedule."""
        # Start the warmup slightly above zero so step 0 has a non-degenerate LR.
        warmup_start = self.peak_lr / (self.warmup_steps + 1)
        return optax.warmup_constant_schedule(
            init_value=warmup_start,
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
        )
|
||||
|
||||
|
||||
@runtime_checkable
class OptimizerConfig(Protocol):
    """Structural interface for optimizer configs.

    `create` receives the learning rate (scalar or schedule) and an optional
    PyTree mask selecting which parameters receive weight decay, and returns
    the optax gradient transformation used for training.
    """

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation: ...
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class AdamW(OptimizerConfig):
    """AdamW optimizer with global-norm gradient clipping."""

    b1: float = 0.9
    b2: float = 0.95
    eps: float = 1e-8
    # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value.
    weight_decay: float = 1e-10
    clip_gradient_norm: float = 1.0

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        """Chain gradient clipping in front of the AdamW update."""
        clip = optax.clip_by_global_norm(self.clip_gradient_norm)
        adamw = optax.adamw(
            lr,
            b1=self.b1,
            b2=self.b2,
            eps=self.eps,
            weight_decay=self.weight_decay,
            mask=weight_decay_mask,
        )
        return optax.chain(clip, adamw)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SGD(OptimizerConfig):
    """Plain SGD optimizer (optionally with momentum / Nesterov)."""

    lr: float = 5e-5
    momentum: float = 0.9
    nesterov: bool = False

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        """Build the optax SGD transformation.

        Raises:
            ValueError: if a weight-decay mask is supplied. `optax.sgd` has no
                weight-decay term, so a mask would otherwise be silently ignored.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently drop this validation.
        if weight_decay_mask is not None:
            raise ValueError("Weight decay is not supported for SGD")
        return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)
|
||||
|
||||
|
||||
def create_optimizer(
    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
) -> optax.GradientTransformation:
    """Instantiate the LR schedule, then build the optimizer around it."""
    lr = lr_schedule.create()
    return optimizer.create(lr, weight_decay_mask=weight_decay_mask)
|
||||
102
policy/openpi-InternData-A1/src/openpi/training/sharding.py
Normal file
102
policy/openpi-InternData-A1/src/openpi/training/sharding.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
BATCH_AXIS = "batch"
|
||||
FSDP_AXIS = "fsdp"
|
||||
# In FSDP, we shard the data across both the batch and FSDP axes.
|
||||
DATA_AXIS = (BATCH_AXIS, FSDP_AXIS)
|
||||
|
||||
|
||||
class _MeshState:
    """Module-level holder for the mesh installed by `set_mesh`.

    `active_mesh` is None whenever no `set_mesh` context is active.
    """

    active_mesh: jax.sharding.Mesh | None = None
|
||||
|
||||
|
||||
def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
    """Create a 2D (batch, fsdp) device mesh covering all available devices."""
    total_devices = jax.device_count()
    if total_devices % num_fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}."
        )
    # The devices not consumed by the FSDP axis form the pure data-parallel (batch) axis.
    batch_dim = total_devices // num_fsdp_devices
    return jax.make_mesh((batch_dim, num_fsdp_devices), (BATCH_AXIS, FSDP_AXIS))
|
||||
|
||||
|
||||
@contextlib.contextmanager
def set_mesh(mesh: jax.sharding.Mesh):
    """Plumbing the mesh deep into the module tree is extremely cumbersome; until the JAX team lands a better API, a
    custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used
    in `activation_sharding_constraint` below."""
    # Nesting is disallowed because there is a single global slot.
    if _MeshState.active_mesh is not None:
        raise ValueError("Cannot nest set_mesh context managers.")
    _MeshState.active_mesh = mesh
    try:
        yield
    finally:
        # Always clear the global, even if the body raised.
        _MeshState.active_mesh = None
|
||||
|
||||
|
||||
def activation_sharding_constraint(pytree):
    """Constrain activations to be sharded along the data axes of the active
    mesh; a no-op when no mesh was installed via `set_mesh`."""
    mesh = _MeshState.active_mesh
    if mesh is None:
        return pytree
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(DATA_AXIS))
    return jax.lax.with_sharding_constraint(pytree, sharding)
|
||||
|
||||
|
||||
def fsdp_sharding(
    pytree,
    mesh: jax.sharding.Mesh,
    *,
    min_size_mbytes: int = 4,  # 4 MiB
    log: bool = False,
):
    """Apply FSDP sharding to a pytree of arrays based on the mesh shape.

    Args:
        pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. contains .shape attr)
            will be considered for sharding.
        mesh: The mesh being used for applying sharding on to pytree.
        min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this
            will be replicated.
        log: If true, will log the sharding decisions for arrays that are being considered for sharding.

    Returns:
        The sharded pytree.
    """
    min_size_bytes = min_size_mbytes * 2**20

    def _shard_arr(kp, array: jax.ShapeDtypeStruct):
        # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging
        if mesh.shape[FSDP_AXIS] == 1:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate scalar and vector arrays
        if not hasattr(array, "shape"):
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        if len(array.shape) < 2:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate small arrays
        if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

        # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension
        # (axes are tried in decreasing-size order; only the first divisible axis is sharded).
        axes = np.argsort(array.shape)[::-1]
        spec = [None] * len(axes)
        for i in axes:
            if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:
                if log:
                    logging.info(
                        f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}"
                    )
                spec[i] = FSDP_AXIS
                return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))

        # replicate if no valid sharding was found
        if log:
            logging.warning(
                f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}"
            )
        return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    return jax.tree_util.tree_map_with_path(_shard_arr, pytree)
|
||||
38
policy/openpi-InternData-A1/src/openpi/training/utils.py
Normal file
38
policy/openpi-InternData-A1/src/openpi/training/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from flax import nnx
|
||||
from flax import struct
|
||||
import jax
|
||||
import optax
|
||||
|
||||
from openpi.models import model as _model
|
||||
from openpi.shared import array_typing as at
|
||||
|
||||
|
||||
@at.typecheck
@struct.dataclass
class TrainState:
    """Flax struct holding everything needed to run and resume training."""

    # Global optimization step.
    step: at.Int[at.ArrayLike, ""]
    # Model parameters (the nnx state split off from the model definition).
    params: nnx.State
    # Static model graph definition; combined with `params` to rebuild the model.
    model_def: nnx.GraphDef[_model.BaseModel]
    # Optimizer state matching `tx`.
    opt_state: optax.OptState
    # Gradient transformation; not a pytree node, so it is not traced/checkpointed.
    tx: optax.GradientTransformation = struct.field(pytree_node=False)

    # EMA decay rate; None disables EMA tracking.
    ema_decay: float | None = struct.field(pytree_node=False)
    # EMA copy of the parameters (only populated when `ema_decay` is set — TODO confirm in train loop).
    ema_params: nnx.State | None = None
|
||||
|
||||
|
||||
@at.typecheck
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Render a PyTree as one "path: value" line per leaf for logging.

    `interp_func` converts each leaf to its displayed string (defaults to `str`).
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    lines = [f"{jax.tree_util.keystr(path)}: {interp_func(leaf)}" for path, leaf in leaves_with_paths]
    return "\n".join(lines)
|
||||
|
||||
|
||||
@at.typecheck
def array_tree_to_info(tree: at.PyTree) -> str:
    """Render a PyTree of arrays as "path: shape@dtype" lines for logging."""

    def describe(arr) -> str:
        return f"{arr.shape}@{arr.dtype}"

    return tree_to_info(tree, describe)
|
||||
@@ -0,0 +1,103 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import re
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import flax.traverse_util
|
||||
import numpy as np
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.download as download
|
||||
from pathlib import Path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
class WeightLoader(Protocol):
    """Structural interface for objects that initialize model parameters."""

    def load(self, params: at.Params) -> at.Params:
        """Loads the model weights.

        Args:
            params: Parameters of the model. This is a nested structure of array-like objects that
                represent the model's parameters.

        Returns:
            Loaded parameters. The structure must be identical to `params`. If returning a subset of
            the parameters the loader must merge the loaded parameters with `params`.
        """
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class NoOpWeightLoader(WeightLoader):
    """Keeps the freshly initialized parameters unchanged (no checkpoint)."""

    def load(self, params: at.Params) -> at.Params:
        return params
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CheckpointWeightLoader(WeightLoader):
    """Loads an entire set of weights from a checkpoint.

    Compatible with:
      trained checkpoints:
        example: "./checkpoints/<config>/<exp>/<step>/params"
      released checkpoints:
        example: "gs://openpi-assets/checkpoints/<model>/params"
    """

    # Local path or remote (e.g. gs://) location of the checkpoint params.
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        # We are loading np.ndarray and relying on the training code to properly convert and shard the params.
        loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray)
        # Add all missing LoRA weights.
        return _merge_params(loaded_params, params, missing_regex=".*lora.*")
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class PaliGemmaWeightLoader(WeightLoader):
    """Loads weights from the official PaliGemma checkpoint.

    This will overwrite existing weights with similar names while keeping all extra weights intact.
    This allows us to support the action expert which is used by the Pi0 model.
    """

    # Path to the official PaliGemma .npz checkpoint file.
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        path = Path(self.params_path)
        with path.open("rb") as f:
            # The checkpoint is a flat npz archive with '/'-separated keys.
            flat_params = dict(np.load(f, allow_pickle=False))
        loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]}
        # Add all missing weights.
        return _merge_params(loaded_params, params, missing_regex=".*")
|
||||
|
||||
|
||||
|
||||
def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
    """Merges the loaded parameters with the reference parameters.

    Args:
        loaded_params: The parameters to merge.
        params: The reference parameters.
        missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.

    Returns:
        A new dictionary with the merged parameters.
    """
    flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
    flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")

    # Start with every loaded weight that also exists in the reference tree,
    # cast to the reference dtype when the dtypes differ.
    merged = {}
    for key, value in flat_loaded.items():
        if key in flat_ref:
            ref_dtype = flat_ref[key].dtype
            merged[key] = value.astype(ref_dtype) if value.dtype != ref_dtype else value

    # Drop remaining references to the loaded arrays eagerly.
    flat_loaded.clear()

    # Fill in any reference weights whose keys match the "missing" pattern.
    pattern = re.compile(missing_regex)
    for key in flat_ref:
        if key not in merged and pattern.fullmatch(key):
            merged[key] = flat_ref[key]

    return flax.traverse_util.unflatten_dict(merged, sep="/")
|
||||
597
policy/openpi-InternData-A1/src/openpi/transforms.py
Normal file
597
policy/openpi-InternData-A1/src/openpi/transforms.py
Normal file
@@ -0,0 +1,597 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
import dataclasses
|
||||
import re
|
||||
from typing import Protocol, TypeAlias, TypeVar, runtime_checkable
|
||||
|
||||
import flax.traverse_util as traverse_util
|
||||
import jax
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
|
||||
from openpi.models import tokenizer as _tokenizer
|
||||
from openpi.shared import array_typing as at
|
||||
from openpi.shared import normalize as _normalize
|
||||
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
from pdb import set_trace
|
||||
DataDict: TypeAlias = at.PyTree
|
||||
NormStats: TypeAlias = _normalize.NormStats
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
@runtime_checkable
class DataTransformFn(Protocol):
    """Structural interface for all data transforms in this module."""

    def __call__(self, data: DataDict) -> DataDict:
        """Apply transformation to the data.

        Args:
            data: The data to apply the transform to. This is a possibly nested dictionary that contains
                unbatched data elements. Each leaf is expected to be a numpy array. Using JAX arrays is allowed
                but not recommended since it may result in extra GPU memory usage inside data loader worker
                processes.

        Returns:
            The transformed data. Could be the input `data` that was modified in place, or a new data structure.
        """
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Group:
    """A pair of transform pipelines: one for model inputs, one for outputs."""

    # Applied, in order, to data flowing into the model.
    inputs: Sequence[DataTransformFn] = ()
    # Applied, in order, to data flowing out of the model.
    outputs: Sequence[DataTransformFn] = ()

    def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Sequence[DataTransformFn] = ()) -> "Group":
        """Append transforms to the group and return a new group.

        Args:
            inputs: Appended to the *end* of the current input transforms.
            outputs: Appended to the *beginning* of the current output transforms.

        Returns:
            A new group with the appended transforms.
        """
        extended_inputs = tuple(self.inputs) + tuple(inputs)
        extended_outputs = tuple(outputs) + tuple(self.outputs)
        return Group(inputs=extended_inputs, outputs=extended_outputs)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CompositeTransform(DataTransformFn):
    """Applies each contained transform to the data, first to last."""

    transforms: Sequence[DataTransformFn]

    def __call__(self, data: DataDict) -> DataDict:
        result = data
        for step in self.transforms:
            result = step(result)
        return result
|
||||
|
||||
|
||||
def compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn:
    """Compose a sequence of transforms into a single transform.

    The transforms are applied in the order given.
    """
    return CompositeTransform(transforms)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class RepackTransform(DataTransformFn):
    """Repacks an input dictionary into a new dictionary.

    Repacking is defined using a dictionary where the keys are the new keys and the values
    are the flattened paths to the old keys. We use '/' as the separator during flattening.

    Example:
        {
            "images": {
                "cam_high": "observation.images.top",
                "cam_low": "observation.images.bottom",
            },
            "state": "observation.state",
            "actions": "action",
        }
    """

    # Target structure whose string leaves are '/'-separated paths into `data`.
    structure: at.PyTree[str]

    def __call__(self, data: DataDict) -> DataDict:
        flat_item = flatten_dict(data)
        # Raises KeyError if a referenced path is missing from `data`.
        return jax.tree.map(lambda k: flat_item[k], self.structure)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ReTransform(DataTransformFn):
    """Repacks an input dictionary into a new dictionary.

    Repacking is defined using a dictionary where the keys are the new keys and the values
    are the flattened paths to the old keys. We use '/' as the separator during flattening.

    Example:
        {
            "images": {
                "cam_high": "observation.images.top",
                "cam_low": "observation.images.bottom",
            },
            "state": "observation.state",
            "actions": "action",
        }
    """

    # Target structure whose string leaves are '/'-separated paths into `data`.
    structure: at.PyTree[str]

    def __call__(self, data: DataDict) -> DataDict:
        # NOTE: a leftover `import pdb; pdb.set_trace()` debugger breakpoint
        # was removed here — it would have suspended every data-loader worker
        # that invoked this transform.
        flat_item = flatten_dict(data)
        # Raises KeyError if a referenced path is missing from `data`.
        return jax.tree.map(lambda k: flat_item[k], self.structure)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class InjectDefaultPrompt(DataTransformFn):
    """Adds a fallback prompt when the sample does not already carry one."""

    prompt: str | None

    def __call__(self, data: DataDict) -> DataDict:
        # No default configured, or the sample already has a prompt: leave as-is.
        if self.prompt is None or "prompt" in data:
            return data
        data["prompt"] = np.asarray(self.prompt)
        return data
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Normalize(DataTransformFn):
    """Normalizes entries of `data` whose flattened keys appear in `norm_stats`."""

    norm_stats: at.PyTree[NormStats] | None
    # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
    use_quantiles: bool = False
    # If true, will raise an error if any of the keys in the norm stats are not present in the data.
    strict: bool = False

    def __post_init__(self):
        # Quantile normalization requires q01/q99 to be present in every stats leaf.
        if self.norm_stats is not None and self.use_quantiles:
            _assert_quantile_stats(self.norm_stats)

    def __call__(self, data: DataDict) -> DataDict:
        # No-op when no stats are configured.
        if self.norm_stats is None:
            return data

        return apply_tree(
            data,
            self.norm_stats,
            self._normalize_quantile if self.use_quantiles else self._normalize,
            strict=self.strict,
        )

    def _normalize(self, x, stats: NormStats):
        # Stats are sliced to the trailing dim of `x`, so extra stat dims are ignored.
        mean, std = stats.mean[..., : x.shape[-1]], stats.std[..., : x.shape[-1]]
        return (x - mean) / (std + 1e-6)

    def _normalize_quantile(self, x, stats: NormStats):
        assert stats.q01 is not None
        assert stats.q99 is not None
        # Maps [q01, q99] to [-1, 1]; values outside the quantile range fall outside that interval.
        q01, q99 = stats.q01[..., : x.shape[-1]], stats.q99[..., : x.shape[-1]]
        return (x - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class Unnormalize(DataTransformFn):
    """Inverse of `Normalize` for entries whose keys appear in `norm_stats`."""

    norm_stats: at.PyTree[NormStats] | None
    # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.
    use_quantiles: bool = False

    def __post_init__(self):
        if self.norm_stats is not None and self.use_quantiles:
            _assert_quantile_stats(self.norm_stats)

    def __call__(self, data: DataDict) -> DataDict:
        if self.norm_stats is None:
            return data

        # Make sure that all the keys in the norm stats are present in the data.
        return apply_tree(
            data,
            self.norm_stats,
            self._unnormalize_quantile if self.use_quantiles else self._unnormalize,
            strict=True,
        )

    def _unnormalize(self, x, stats: NormStats):
        # Stats are padded (mean with 0, std with 1) so trailing dims of `x`
        # beyond the stats pass through unchanged.
        mean = pad_to_dim(stats.mean, x.shape[-1], axis=-1, value=0.0)
        std = pad_to_dim(stats.std, x.shape[-1], axis=-1, value=1.0)
        return x * (std + 1e-6) + mean

    def _unnormalize_quantile(self, x, stats: NormStats):
        assert stats.q01 is not None
        assert stats.q99 is not None
        q01, q99 = stats.q01, stats.q99
        # Dims beyond the quantile stats are passed through unchanged.
        if (dim := q01.shape[-1]) < x.shape[-1]:
            return np.concatenate([(x[..., :dim] + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01, x[..., dim:]], axis=-1)
        return (x + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ResizeImages(DataTransformFn):
    """Resizes every entry of data["image"] to (height, width) via
    `image_tools.resize_with_pad`."""

    height: int
    width: int

    def __call__(self, data: DataDict) -> DataDict:
        data["image"] = {k: image_tools.resize_with_pad(v, self.height, self.width) for k, v in data["image"].items()}
        return data
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SubsampleActions(DataTransformFn):
    """Keeps every `stride`-th step of the action sequence."""

    stride: int

    def __call__(self, data: DataDict) -> DataDict:
        subsampled = data["actions"][:: self.stride]
        data["actions"] = subsampled
        return data
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class DeltaActions(DataTransformFn):
    """Repacks absolute actions into delta action space.

    Note: the subtraction mutates the incoming `actions` array in place.
    """

    # Boolean mask for the action dimensions to be repacked into delta action space. Length
    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
    # See `make_bool_mask` for more details.
    mask: Sequence[bool] | None

    def __call__(self, data: DataDict) -> DataDict:
        if "actions" not in data or self.mask is None:
            return data

        state, actions = data["state"], data["actions"]
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        # Subtract the current state from the masked dims; expand_dims broadcasts
        # the state across the action-horizon axis.
        actions[..., :dims] -= np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)
        data["actions"] = actions

        return data
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class DeltaActionsPose(DataTransformFn):
    """Repacks absolute actions into delta action space, treating the first six
    dims as a pose converted via `relative_pose`.

    Consumes (deletes) data["pose"].
    """

    # Boolean mask for the action dimensions to be repacked into delta action space. Length
    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
    # See `make_bool_mask` for more details.
    mask: Sequence[bool] | None

    def __call__(self, data: DataDict) -> DataDict:
        # set_trace()
        if "actions" not in data or self.mask is None:
            return data

        pose, actions = data["pose"], data["actions"]
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        act = actions[..., :dims]
        st = pose[..., :dims]
        # The first six dims are treated as a pose (presumably position +
        # rotation — TODO confirm against the data pipeline).
        pose_mask = mask[:6]
        if np.any(pose_mask):
            pose_action = act[..., :6]
            pose_state = st[..., :6]

            # `relative_pose` is defined elsewhere in this module/package;
            # applied per action-horizon step.
            if pose_action.ndim == 2:
                rel_list = []
                for i in range(pose_action.shape[0]):
                    rel_list.append(relative_pose(pose_action[i], pose_state))
                rel_pose = np.stack(rel_list, axis=0)
            else:
                raise ValueError("pose_action must be dim 2")

            # NOTE(review): kept inside the guard — `rel_pose` is undefined
            # when no pose dim is masked.
            act[..., :6] = rel_pose
        data["actions"][..., :dims] = act
        del data["pose"]
        return data
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class AbsoluteActions(DataTransformFn):
    """Repacks delta actions into absolute action space.

    Note: the addition mutates the incoming `actions` array in place.
    """

    # Boolean mask for the action dimensions to be repacked into absolute action space. Length
    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
    # See `make_bool_mask` for more details.
    mask: Sequence[bool] | None

    def __call__(self, data: DataDict) -> DataDict:
        if "actions" not in data or self.mask is None:
            return data

        state, actions = data["state"], data["actions"]
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        # Add the current state back onto the masked dims; expand_dims broadcasts
        # the state across the action-horizon axis.
        actions[..., :dims] += np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)
        data["actions"] = actions

        return data
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class AbsoluteActionsPose:
    """Convert relative pose actions back into absolute pose actions."""

    # Boolean mask selecting the dims to convert; None disables the transform.
    mask: Sequence[bool] | None

    def __call__(self, data):
        # No-op unless actions, a reference pose, and a mask are all present.
        if "actions" not in data or "pose" not in data or self.mask is None:
            return data

        actions = data["actions"]  # (T, D)
        pose = data["pose"]  # (D,)
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]

        act = actions[..., :dims]
        st = pose[..., :dims]

        # The first six dims are treated as a pose (presumably position +
        # rotation — TODO confirm against the data pipeline).
        pose_mask = mask[:6]
        if np.any(pose_mask):
            pose_action = act[..., :6]
            pose_state = st[..., :6]

            # `absolute_pose` is defined elsewhere in this module/package;
            # applied per action-horizon step.
            abs_list = []
            for i in range(pose_action.shape[0]):
                abs_list.append(absolute_pose(pose_action[i], pose_state))
            abs_pose = np.stack(abs_list, axis=0)
            act[..., :6] = abs_pose

        # NOTE(review): unlike DeltaActionsPose, data["pose"] is kept here.
        data["actions"][..., :dims] = act
        return data
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class TokenizePrompt(DataTransformFn):
    """Tokenizes data["prompt"] into `tokenized_prompt` / `tokenized_prompt_mask`."""

    tokenizer: _tokenizer.PaligemmaTokenizer
    # When true, the state is also handed to the tokenizer (so it can be folded
    # into the prompt — TODO confirm tokenizer semantics).
    discrete_state_input: bool = False

    def __call__(self, data: DataDict) -> DataDict:
        # `prompt` is consumed (popped) by this transform.
        if (prompt := data.pop("prompt", None)) is None:
            raise ValueError("Prompt is required")

        if self.discrete_state_input:
            if (state := data.get("state", None)) is None:
                raise ValueError("State is required.")
        else:
            state = None

        # Prompts may arrive as 0-d numpy string arrays from the data loader.
        if not isinstance(prompt, str):
            prompt = prompt.item()

        tokens, token_masks = self.tokenizer.tokenize(prompt, state)
        return {**data, "tokenized_prompt": tokens, "tokenized_prompt_mask": token_masks}
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class TokenizeFASTInputs(DataTransformFn):
    """Tokenizes prompt + state (+ actions when present) for FAST-style models."""

    tokenizer: _tokenizer.FASTTokenizer

    def __call__(self, data: DataDict) -> DataDict:
        # `prompt` is consumed (popped) by this transform.
        if (prompt := data.pop("prompt", None)) is None:
            raise ValueError("Prompt is required")

        # Prompts may arrive as 0-d numpy string arrays from the data loader.
        if not isinstance(prompt, str):
            prompt = prompt.item()

        # Actions are optional (e.g. absent at inference time).
        state, actions = data["state"], data.get("actions")
        tokens, token_mask, ar_mask, loss_mask = self.tokenizer.tokenize(prompt, state, actions)
        return {
            **data,
            "tokenized_prompt": tokens,
            "tokenized_prompt_mask": token_mask,
            "token_ar_mask": ar_mask,
            "token_loss_mask": loss_mask,
        }
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ExtractFASTActions(DataTransformFn):
    """Decodes FAST model output tokens into a continuous action array."""

    tokenizer: _tokenizer.FASTTokenizer
    # Number of action steps to decode.
    action_horizon: int
    # Dimensionality of each decoded action.
    action_dim: int

    def __call__(self, data: DataDict) -> DataDict:
        if "actions" not in data:
            return data
        # Model outputs are saved in "actions", but for FAST models they represent tokens.
        tokens = data.pop("actions")
        actions = self.tokenizer.extract_actions(tokens.astype(np.int32), self.action_horizon, self.action_dim)
        return {
            **data,
            "actions": actions,
        }
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class PromptFromLeRobotTask(DataTransformFn):
    """Extracts a prompt from the current LeRobot dataset task."""

    # Contains the LeRobot dataset tasks (dataset.meta.tasks).
    tasks: dict[int, str]

    def __call__(self, data: DataDict) -> DataDict:
        if "task_index" not in data:
            raise ValueError('Cannot extract prompt without "task_index"')

        task_index = int(data["task_index"])
        prompt = self.tasks.get(task_index)
        if prompt is None:
            raise ValueError(f"{task_index=} not found in task mapping: {self.tasks}")

        return {**data, "prompt": prompt}
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class PadStatesAndActions(DataTransformFn):
    """Zero-pads states and actions to the model action dimension."""

    model_action_dim: int

    def __call__(self, data: DataDict) -> DataDict:
        target = self.model_action_dim
        data["state"] = pad_to_dim(data["state"], target, axis=-1)
        # Actions are optional (e.g. absent at inference time).
        if "actions" in data:
            data["actions"] = pad_to_dim(data["actions"], target, axis=-1)
        return data
|
||||
|
||||
|
||||
def flatten_dict(tree: at.PyTree) -> dict:
    """Flatten a nested dictionary. Uses '/' as the separator.

    Thin wrapper around `flax.traverse_util.flatten_dict`.
    """
    return traverse_util.flatten_dict(tree, sep="/")
|
||||
|
||||
|
||||
def unflatten_dict(tree: dict) -> at.PyTree:
    """Unflatten a flattened dictionary. Assumes that '/' was used as a separator.

    Thin wrapper around `flax.traverse_util.unflatten_dict`.
    """
    return traverse_util.unflatten_dict(tree, sep="/")
|
||||
|
||||
|
||||
def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree:
    """Transform the structure of a nested dictionary using a set of patterns.

    The transformation is defined using the `patterns` dictionary. The keys are the
    input keys that should be matched and the values are the new names inside the output
    dictionary. If the value is None, the input key is removed.

    Both keys and values should represent flattened paths using '/' as the separator.
    Keys can be regular expressions and values can include backreferences to the
    matched groups (see `re.sub` for more details). Note that the regular expression
    must match the entire key.

    The order inside the `patterns` dictionary is important. Only the first pattern that
    matches the input key will be used.

    See unit tests for more examples.

    Args:
        patterns: A mapping from old keys to new keys.
        tree: The nested dictionary to transform.

    Returns:
        The transformed nested dictionary.
    """
    data = flatten_dict(tree)

    # Compile the patterns.
    compiled = {re.compile(k): v for k, v in patterns.items()}

    output = {}
    for k in data:
        for pattern, repl in compiled.items():
            if pattern.fullmatch(k):
                # First matching pattern wins; None drops the key entirely.
                new_k = pattern.sub(repl, k, count=1) if repl is not None else None
                break
        else:
            # Use the original key if no match is found.
            new_k = k

        if new_k is not None:
            if new_k in output:
                raise ValueError(f"Key '{new_k}' already exists in output")
            output[new_k] = data[k]

    # Validate the output structure to make sure that it can be unflattened:
    # a leaf path must never also be a prefix of another path.
    names = sorted(output)
    for i in range(len(names) - 1):
        name, next_name = names[i : i + 2]
        if next_name.startswith(name + "/"):
            raise ValueError(f"Leaf '{name}' aliases a node of '{next_name}'")

    return unflatten_dict(output)
|
||||
|
||||
|
||||
def apply_tree(
    tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False
) -> at.PyTree[T]:
    """Apply `fn(leaf, selector_leaf)` to every leaf of `tree` whose flattened
    key also appears in `selector`; all other leaves pass through unchanged.

    Args:
        tree: The nested dictionary to transform.
        selector: Nested dictionary whose leaves are paired with matching `tree` leaves.
        fn: Function applied to (tree_leaf, selector_leaf).
        strict: If true, raise ValueError when a selector key is missing from `tree`.
    """
    tree = flatten_dict(tree)
    selector = flatten_dict(selector)

    def transform(k: str, v: T) -> T:
        if k in selector:
            return fn(v, selector[k])
        return v

    if strict:
        for k in selector:
            if k not in tree:
                raise ValueError(f"Selector key {k} not found in tree")

    return unflatten_dict({k: transform(k, v) for k, v in tree.items()})
|
||||
|
||||
|
||||
def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:
    """Pad an array along `axis` until that axis has size `target_dim`.

    Padding is appended at the end of the axis and filled with `value`
    (the previous docstring said "zeros", but any constant can be used).

    Args:
        x: The array to pad.
        target_dim: The desired size of `axis`.
        axis: The axis to pad along (default: last).
        value: The constant fill value for the appended padding.

    Returns:
        The padded array, or `x` unchanged (same object, no truncation) when
        the axis is already at least `target_dim` long.
    """
    current_dim = x.shape[axis]
    if current_dim >= target_dim:
        return x
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, target_dim - current_dim)
    return np.pad(x, pad_width, constant_values=value)
|
||||
|
||||
|
||||
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Each positive dimension contributes that many `True` entries, each
    negative dimension contributes `abs(dim)` `False` entries, and zeros
    contribute nothing.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    return tuple(flag for dim in dims for flag in [dim > 0] * abs(dim))
|
||||
|
||||
|
||||
def _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None:
    """Raise ValueError if any norm-stats leaf is missing its q01/q99 quantiles."""
    for key, stats in flatten_dict(norm_stats).items():
        if stats.q01 is not None and stats.q99 is not None:
            continue
        raise ValueError(
            f"quantile stats must be provided if use_quantile_norm is True. Key {key} is missing q01 or q99."
        )
|
||||
|
||||
|
||||
def pose6d_to_pose(pose6d, degrees=False):
    """Convert a 6-DoF pose vector to a 4x4 homogeneous transform.

    Args:
        pose6d: Array of shape (6,) — translation (x, y, z) followed by
            xyz Euler angles.
        degrees: Whether the Euler angles are in degrees (default: radians).

    Returns:
        A (4, 4) homogeneous transformation matrix.
    """
    transform = np.eye(4)
    transform[:3, :3] = R.from_euler("xyz", pose6d[3:], degrees=degrees).as_matrix()
    transform[:3, 3] = pose6d[:3]
    return transform
|
||||
|
||||
|
||||
def pose_to_6d(pose, degrees=False):
    """Convert a 4x4 homogeneous transform to a 6-DoF pose vector.

    Args:
        pose: A (4, 4) homogeneous transformation matrix.
        degrees: Whether to express the Euler angles in degrees
            (default: radians).

    Returns:
        Array of shape (6,) — translation (x, y, z) followed by xyz Euler
        angles.
    """
    angles = R.from_matrix(pose[:3, :3]).as_euler("xyz", degrees=degrees)
    return np.concatenate([pose[:3, 3], angles], axis=0)
|
||||
|
||||
|
||||
def relative_pose(pose_action, pose_state):
    """Express a 6-DoF action pose relative to a 6-DoF state pose.

    Both inputs are (6,) pose vectors with radian Euler angles; the result is
    the action pose in the state's frame, also as a (6,) vector.
    """
    action_mat = pose6d_to_pose(pose_action, degrees=False)
    state_mat = pose6d_to_pose(pose_state, degrees=False)
    rel_mat = np.linalg.inv(state_mat) @ action_mat
    return pose_to_6d(rel_mat, degrees=False)
|
||||
|
||||
def absolute_pose(pose_delta, pose_state):
    """Compose a 6-DoF delta pose onto a 6-DoF state pose.

    Inverse of `relative_pose`: both inputs are (6,) pose vectors with radian
    Euler angles; the result is the delta applied in the state's frame, as a
    (6,) vector.
    """
    delta_mat = pose6d_to_pose(pose_delta, degrees=False)
    state_mat = pose6d_to_pose(pose_state, degrees=False)
    composed = state_mat @ delta_mat
    return pose_to_6d(composed, degrees=False)
|
||||
121
policy/openpi-InternData-A1/src/openpi/transforms_test.py
Normal file
121
policy/openpi-InternData-A1/src/openpi/transforms_test.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import openpi.models.tokenizer as _tokenizer
|
||||
import openpi.transforms as _transforms
|
||||
|
||||
|
||||
def test_repack_transform():
    # Repacking pulls nested values into a new structure by flattened path.
    repack = _transforms.RepackTransform(
        structure={
            "a": {"b": "b/c"},
            "d": "e/f",
        }
    )
    assert repack({"b": {"c": 1}, "e": {"f": 2}}) == {"a": {"b": 1}, "d": 2}
|
||||
|
||||
|
||||
def test_delta_actions():
    sample = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}

    result = _transforms.DeltaActions(mask=[False, True])(sample)

    # The state is untouched; only the masked action dimension becomes a delta.
    assert np.all(result["state"] == np.array([1, 2, 3]))
    assert np.all(result["actions"] == np.array([[3, 2, 5], [5, 4, 7]]))
|
||||
|
||||
|
||||
def test_delta_actions_noop():
    sample = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}

    # A disabled mask leaves the item untouched (same object).
    assert _transforms.DeltaActions(mask=None)(sample) is sample

    # So does an item that carries no actions at all.
    del sample["actions"]
    assert _transforms.DeltaActions(mask=[True, False])(sample) is sample
|
||||
|
||||
|
||||
def test_absolute_actions():
    sample = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}

    result = _transforms.AbsoluteActions(mask=[False, True])(sample)

    # The state is untouched; the masked action dimension becomes absolute.
    assert np.all(result["state"] == np.array([1, 2, 3]))
    assert np.all(result["actions"] == np.array([[3, 6, 5], [5, 8, 7]]))
|
||||
|
||||
|
||||
def test_absolute_actions_noop():
    sample = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])}

    # A disabled mask leaves the item untouched (same object).
    assert _transforms.AbsoluteActions(mask=None)(sample) is sample

    # So does an item that carries no actions at all.
    del sample["actions"]
    assert _transforms.AbsoluteActions(mask=[True, False])(sample) is sample
|
||||
|
||||
|
||||
def test_make_bool_mask():
    # Positive runs become True, negative runs become False, zeros vanish.
    assert _transforms.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert _transforms.make_bool_mask(2, 0, 2) == (True,) * 4
|
||||
|
||||
|
||||
def test_tokenize_prompt():
    tokenizer = _tokenizer.PaligemmaTokenizer(max_len=12)
    data = _transforms.TokenizePrompt(tokenizer)({"prompt": "Hello, world!"})

    # The transform must produce exactly what the tokenizer itself produces.
    expected_tokens, expected_mask = tokenizer.tokenize("Hello, world!")
    assert np.allclose(data["tokenized_prompt"], expected_tokens)
    assert np.allclose(data["tokenized_prompt_mask"], expected_mask)
|
||||
|
||||
|
||||
def test_tokenize_no_prompt():
    tokenize = _transforms.TokenizePrompt(_tokenizer.PaligemmaTokenizer())

    # An item with no prompt key must be rejected.
    with pytest.raises(ValueError, match="Prompt is required"):
        tokenize({})
|
||||
|
||||
|
||||
def test_transform_dict():
    # Rename one key and drop another.
    tree = {"a": {"b": 1, "c": 2}}
    assert _transforms.transform_dict({"a/b": "a/c", "a/c": None}, tree) == {"a": {"c": 1}}

    # A rename that collides with an existing output key is an error.
    with pytest.raises(ValueError, match="Key 'a/c' already exists in output"):
        _transforms.transform_dict({"a/b": "a/c"}, tree)

    # Patterns must match the full flattened key, so a bare node name matches nothing.
    tree = {"a": {"b": 1, "c": 2}}
    assert _transforms.transform_dict({"a": None}, tree) == tree

    # A regex that covers the entire key removes everything under it.
    tree = {"a": {"b": 1, "c": 2}}
    assert _transforms.transform_dict({"a.+": None}, tree) == {}

    # Backreferences: rename every leaf named 'c' to 'd'.
    tree = {"a": {"b": 1, "c": 1}, "b": {"c": 2}}
    assert _transforms.transform_dict({"(.+)/c": r"\1/d"}, tree) == {"a": {"b": 1, "d": 1}, "b": {"d": 2}}
|
||||
|
||||
|
||||
def test_extract_prompt_from_task():
    lookup = _transforms.PromptFromLeRobotTask({1: "Hello, world!"})

    # A known task index resolves to its prompt string.
    assert lookup({"task_index": 1})["prompt"] == "Hello, world!"

    # An unknown task index is an error.
    with pytest.raises(ValueError, match="task_index=2 not found in task mapping"):
        lookup({"task_index": 2})
|
||||
Reference in New Issue
Block a user