multi-node openpi commit
This commit is contained in:
459
policy/openpi-InternData-A1/src/openpi/models/gemma.py
Normal file
459
policy/openpi-InternData-A1/src/openpi/models/gemma.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Gemma adaptation for Pi, taken from big_vision.
|
||||
|
||||
We follow this einsum axis naming convention:
|
||||
B: batch
|
||||
T: query length
|
||||
S: k/v length
|
||||
N: num query heads
|
||||
K: num k/v heads
|
||||
G: num query heads per k/v head
|
||||
H: head dim
|
||||
D: d_model ("features")
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.lora as lora
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.training.sharding as sharding
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257_152
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]
|
||||
|
||||
|
||||
def get_config(variant: Variant) -> Config:
|
||||
"""Returns config for specified gemma variant."""
|
||||
if variant == "dummy":
|
||||
return Config(
|
||||
width=64,
|
||||
depth=4,
|
||||
mlp_dim=128,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=16,
|
||||
)
|
||||
if variant == "gemma_300m":
|
||||
# 311M params
|
||||
return Config(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
if variant == "gemma_2b":
|
||||
return Config(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
if variant == "gemma_2b_lora":
|
||||
return Config(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
|
||||
)
|
||||
if variant == "gemma_300m_lora":
|
||||
# 311M params
|
||||
return Config(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
|
||||
)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class RMSNorm(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, x, cond):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
|
||||
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
|
||||
if cond is None:
|
||||
# regular RMSNorm
|
||||
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
|
||||
normed_inputs = normed_inputs * (
|
||||
1 + scale
|
||||
) # scale by learned parameter in float32 (matches Flax implementation)
|
||||
return normed_inputs.astype(dtype), None # return in original dtype
|
||||
|
||||
# adaptive RMSNorm
|
||||
modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
|
||||
scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
|
||||
normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32
|
||||
return normed_inputs.astype(dtype), gate
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Embedder(nn.Module):
|
||||
"""Embedder module."""
|
||||
|
||||
vocab_size: int
|
||||
embed_dim: int
|
||||
|
||||
def setup(self):
|
||||
self.input_embedding_table = self.param(
|
||||
"input_embedding",
|
||||
nn.initializers.normal(),
|
||||
(self.vocab_size, self.embed_dim),
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
x = self.input_embedding_table[(x,)]
|
||||
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return jnp.dot(x, self.input_embedding_table.T)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Attention(nn.Module):
|
||||
"""Attention module."""
|
||||
|
||||
configs: Sequence[Config]
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, xs, positions, attn_mask, kv_cache):
|
||||
# all experts must share the same head dim, num heads, and num kv heads for self-attention to work
|
||||
assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
|
||||
assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
|
||||
assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)
|
||||
|
||||
dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision
|
||||
|
||||
qkvs = []
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is None:
|
||||
continue
|
||||
if config.num_kv_heads == config.num_heads:
|
||||
qkv_einsum = lora.Einsum(
|
||||
shape=(3, config.num_heads, config.width, config.head_dim),
|
||||
name=_name("qkv_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
|
||||
else:
|
||||
q_einsum = lora.Einsum(
|
||||
shape=(config.num_heads, config.width, config.head_dim),
|
||||
name=_name("q_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
q = q_einsum("BTD,NDH->BTNH", x)
|
||||
kv_einsum = lora.Einsum(
|
||||
shape=(2, config.num_kv_heads, config.width, config.head_dim),
|
||||
name=_name("kv_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
k, v = kv_einsum("BSD,2KDH->2BSKH", x)
|
||||
qkvs.append((q, k, v))
|
||||
|
||||
q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
|
||||
|
||||
q = _apply_rope(q, positions=positions)
|
||||
q *= self.configs[0].head_dim ** -0.5
|
||||
|
||||
k = _apply_rope(k, positions=positions)
|
||||
|
||||
# should still be half-precision here (if input was half-precision)
|
||||
assert q.dtype == k.dtype == v.dtype == dtype
|
||||
|
||||
if kv_cache is not None:
|
||||
cache_k, cache_v = kv_cache
|
||||
k = jnp.concatenate([cache_k, k], axis=1)
|
||||
v = jnp.concatenate([cache_v, v], axis=1)
|
||||
|
||||
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
|
||||
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
|
||||
|
||||
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
|
||||
raise ValueError(
|
||||
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
|
||||
)
|
||||
|
||||
# big_neg = jnp.finfo(logits.dtype).min
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
|
||||
|
||||
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
|
||||
|
||||
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
|
||||
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
|
||||
|
||||
out = []
|
||||
start = 0
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is not None:
|
||||
end = start + x.shape[1]
|
||||
out_einsum = lora.Einsum(
|
||||
shape=(config.num_heads, config.head_dim, config.width),
|
||||
name=_name("attn_vec_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
|
||||
start = end
|
||||
else:
|
||||
out.append(None)
|
||||
|
||||
return out, (k, v)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class FeedForward(nn.Module):
|
||||
"""Feed forward module."""
|
||||
|
||||
features: int
|
||||
hidden_dim: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
w_gating = self.param(
|
||||
"gating_einsum",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
(2, self.features, self.hidden_dim),
|
||||
).astype(dtype)
|
||||
ff_gate = jnp.dot(x, w_gating[0])
|
||||
gate_value = nn.gelu(ff_gate)
|
||||
|
||||
ff1 = jnp.dot(x, w_gating[1])
|
||||
activations = gate_value * ff1
|
||||
|
||||
w_linear = self.param(
|
||||
"linear",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
|
||||
(self.hidden_dim, self.features),
|
||||
).astype(dtype)
|
||||
outputs = jnp.dot(activations, w_linear)
|
||||
assert outputs.dtype == dtype
|
||||
return outputs
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Block(nn.Module):
|
||||
"""Transformer block."""
|
||||
|
||||
configs: tuple[Config, ...]
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = ()
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True): # noqa: FBT002
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x
|
||||
|
||||
attn = Attention(configs=self.configs, name="attn")
|
||||
|
||||
pre_attn = []
|
||||
gates = []
|
||||
for i, x in enumerate(xs):
|
||||
if x is not None:
|
||||
x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
|
||||
pre_attn.append(x)
|
||||
gates.append(gate if x is not None else None)
|
||||
|
||||
pre_attn = sharding.activation_sharding_constraint(pre_attn)
|
||||
post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
|
||||
post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
|
||||
post_attn = sharding.activation_sharding_constraint(post_attn)
|
||||
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
|
||||
out = []
|
||||
gates = []
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is not None:
|
||||
x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
|
||||
x = lora.FeedForward( # noqa: PLW2901
|
||||
features=config.width,
|
||||
hidden_dim=config.mlp_dim,
|
||||
name=_name("mlp", i),
|
||||
lora_config=config.lora_configs.get("ffn"),
|
||||
)(x)
|
||||
out.append(x)
|
||||
gates.append(gate if x is not None else None)
|
||||
|
||||
out = sharding.activation_sharding_constraint(out)
|
||||
out = jax.tree.map(lambda x: drop(x, deterministic), out)
|
||||
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
|
||||
return xs, kv_cache
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Module(nn.Module):
|
||||
"""Transformer model, supporting a mixture of different weights for different tokens."""
|
||||
|
||||
configs: Sequence[Config] # list of configs, one for each expert
|
||||
embed_dtype: str
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
|
||||
adarms: bool = False
|
||||
|
||||
def setup(self):
|
||||
# all experts must have the same depth
|
||||
assert all(config.depth == self.configs[0].depth for config in self.configs)
|
||||
|
||||
self.embedder = Embedder(
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
embed_dim=self.configs[0].width, # embedder for first expert only
|
||||
name="embedder",
|
||||
)
|
||||
block_cls = nn.remat(
|
||||
Block,
|
||||
prevent_cse=False,
|
||||
static_argnums=(5,), # 0=self, 6=deterministic
|
||||
policy=jax.checkpoint_policies.nothing_saveable,
|
||||
)
|
||||
self.layers = nn.scan(
|
||||
block_cls,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=(
|
||||
0,
|
||||
nn.broadcast,
|
||||
nn.broadcast,
|
||||
nn.broadcast,
|
||||
nn.broadcast,
|
||||
), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
|
||||
length=self.configs[0].depth,
|
||||
)(
|
||||
configs=self.configs,
|
||||
dropout=self.dropout,
|
||||
dropout_bdims=self.dropout_bdims,
|
||||
)
|
||||
self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]
|
||||
|
||||
@at.typecheck
|
||||
def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
|
||||
return self.embedder.encode(tokens).astype(self.embed_dtype)
|
||||
|
||||
@at.typecheck
|
||||
def __call__(
|
||||
self,
|
||||
# list of token arrays, one for each expert, or None if that expert should not be run
|
||||
embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
|
||||
positions: at.Int[at.Array, "b t"],
|
||||
mask: at.Bool[at.Array, "b t s"],
|
||||
adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
|
||||
*,
|
||||
kv_cache: KVCache | None = None,
|
||||
deterministic: bool = True,
|
||||
) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
|
||||
embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
|
||||
mask = jnp.asarray(mask)[:, None, :, :]
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None] * len(self.configs)
|
||||
|
||||
embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)
|
||||
|
||||
assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
|
||||
|
||||
return [
|
||||
f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
|
||||
], kv_cache
|
||||
|
||||
def init(self, use_adarms: Sequence[bool]):
|
||||
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
|
||||
self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
|
||||
self(
|
||||
[jnp.zeros((1, 1, c.width)) for c in self.configs],
|
||||
jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
|
||||
jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
|
||||
adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
|
||||
)
|
||||
|
||||
|
||||
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
||||
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
||||
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None] / timescale[None, None, :]
|
||||
radians = radians[..., None, :]
|
||||
assert radians.dtype == jnp.float32
|
||||
# radians.shape = [...,L,1,d=D/2]
|
||||
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
||||
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
||||
assert res.dtype == jnp.float32
|
||||
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
|
||||
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
|
||||
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
|
||||
# here.
|
||||
return res.astype(x.dtype)
|
||||
|
||||
|
||||
def _name(name, i):
|
||||
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
|
||||
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
|
||||
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
|
||||
# and the action expert.
|
||||
if i == 0:
|
||||
return name
|
||||
return f"{name}_{i}"
|
||||
|
||||
|
||||
def _gated_residual(x, y, gate):
|
||||
assert (x is None) == (y is None)
|
||||
if x is None:
|
||||
return None
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
Reference in New Issue
Block a user