added lora fast model support (#274)
* added lora fast model support * small config mistake * change to ConfigDict types instead of Any as suggested in PR #274 discussion https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945632119 * Simplify get_freeze_filter as per comment on PR #274 https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808 * actually pass the configs haha https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808 * update test to check if lora params are present https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808 * Fixed test to use nnx filters so that it is more clean * run formatter
This commit is contained in:
@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
|
||||
Used for FAST autoregressive policies.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
@@ -25,9 +26,10 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import ml_collections
|
||||
|
||||
import openpi.models.lora as lora
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
Variant = Literal["gemma_2b"]
|
||||
Variant = Literal["gemma_2b", "gemma_2b_lora"]
|
||||
|
||||
|
||||
def get_config(variant):
|
||||
@@ -48,6 +50,26 @@ def get_config(variant):
|
||||
"remat_policy": "nothing_saveable",
|
||||
}
|
||||
)
|
||||
if variant == "gemma_2b_lora":
|
||||
return ml_collections.ConfigDict(
|
||||
{
|
||||
"variant": variant,
|
||||
"width": 2048,
|
||||
"depth": 18,
|
||||
"mlp_dim": 16_384,
|
||||
"num_heads": 8,
|
||||
"num_kv_heads": 1,
|
||||
"head_dim": 256,
|
||||
"norm_eps": 1e-6,
|
||||
"vocab_size": 257_152,
|
||||
"scan": True,
|
||||
"remat_policy": "nothing_saveable",
|
||||
"lora_configs": {
|
||||
"attn": lora.LoRAConfig(rank=16, alpha=16.0),
|
||||
"ffn": lora.LoRAConfig(rank=16, alpha=16.0),
|
||||
},
|
||||
}
|
||||
)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
@@ -110,21 +132,34 @@ class Attention(nn.Module):
|
||||
|
||||
cache_dtype: str | None = None
|
||||
|
||||
lora_config: lora.LoRAConfig | None = None
|
||||
|
||||
def setup(self):
|
||||
if self.num_kv_heads == self.num_heads:
|
||||
self.qkv_einsum = Einsum(
|
||||
self.qkv_einsum = lora.Einsum(
|
||||
shape=(3, self.num_heads, self.features, self.head_dim),
|
||||
name="qkv_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
else:
|
||||
# MQA
|
||||
self.q_einsum = Einsum(
|
||||
self.q_einsum = lora.Einsum(
|
||||
shape=(self.num_heads, self.features, self.head_dim),
|
||||
name="q_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
self.kv_einsum = Einsum(
|
||||
self.kv_einsum = lora.Einsum(
|
||||
shape=(2, self.num_kv_heads, self.features, self.head_dim),
|
||||
name="kv_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
self.attn_vec_einsum = Einsum(
|
||||
self.attn_vec_einsum = lora.Einsum(
|
||||
shape=(self.num_heads, self.head_dim, self.features),
|
||||
name="attn_vec_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
|
||||
def _init_cache(self, k, v, cache_size):
|
||||
@@ -189,37 +224,6 @@ class Attention(nn.Module):
|
||||
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class FeedForward(nn.Module):
|
||||
"""Feed forward module."""
|
||||
|
||||
features: int
|
||||
hidden_dim: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
w_gating = self.param(
|
||||
"gating_einsum",
|
||||
nn.initializers.zeros_init(),
|
||||
((2, self.features, self.hidden_dim)),
|
||||
).astype(dtype)
|
||||
ff_gate = jnp.dot(x, w_gating[0])
|
||||
gate_value = nn.gelu(ff_gate)
|
||||
|
||||
ff1 = jnp.dot(x, w_gating[1])
|
||||
activations = gate_value * ff1
|
||||
|
||||
w_linear = self.param(
|
||||
"linear",
|
||||
nn.initializers.zeros_init(),
|
||||
(self.hidden_dim, self.features),
|
||||
).astype(dtype)
|
||||
outputs = jnp.dot(activations, w_linear)
|
||||
assert outputs.dtype == dtype
|
||||
return outputs
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Block(nn.Module):
|
||||
"""Transformer block."""
|
||||
@@ -233,6 +237,7 @@ class Block(nn.Module):
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = ()
|
||||
cache_dtype: str | None = None
|
||||
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
|
||||
|
||||
def setup(self):
|
||||
self.pre_attention_norm = RMSNorm()
|
||||
@@ -242,9 +247,12 @@ class Block(nn.Module):
|
||||
features=self.embed_dim,
|
||||
head_dim=self.head_dim,
|
||||
cache_dtype=self.cache_dtype,
|
||||
lora_config=self.lora_configs.get("attn"),
|
||||
)
|
||||
self.pre_ffw_norm = RMSNorm()
|
||||
self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
|
||||
self.mlp = lora.FeedForward(
|
||||
features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
|
||||
)
|
||||
if self.dropout:
|
||||
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
|
||||
else:
|
||||
@@ -289,6 +297,7 @@ class Module(nn.Module):
|
||||
|
||||
scan: bool = False
|
||||
remat_policy: str = "none"
|
||||
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
|
||||
|
||||
@nn.compact
|
||||
def __call__(
|
||||
@@ -380,6 +389,7 @@ class Module(nn.Module):
|
||||
"dropout": self.dropout,
|
||||
"dropout_bdims": self.dropout_bdims,
|
||||
"cache_dtype": self.cache_dtype,
|
||||
"lora_configs": self.lora_configs,
|
||||
}
|
||||
layers = self.scope.push("layers")
|
||||
blocks = [
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from flax import nnx
|
||||
import jax
|
||||
import pytest
|
||||
|
||||
@@ -53,6 +54,27 @@ def test_pi0_fast_model():
|
||||
assert actions.shape == (batch_size, 256)
|
||||
|
||||
|
||||
def test_pi0_fast_lora_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size,)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
|
||||
assert actions.shape == (batch_size, 256)
|
||||
|
||||
lora_filter = nnx_utils.PathRegex(".*lora.*")
|
||||
model_state = nnx.state(model)
|
||||
|
||||
lora_state_elems = list(model_state.filter(lora_filter))
|
||||
assert len(lora_state_elems) > 0
|
||||
|
||||
|
||||
@pytest.mark.manual
|
||||
def test_model_restore():
|
||||
key = jax.random.key(0)
|
||||
|
||||
@@ -12,6 +12,7 @@ from openpi.models import model as _model
|
||||
import openpi.models.gemma_fast as _gemma
|
||||
import openpi.models.siglip as _siglip
|
||||
from openpi.shared import array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
@@ -117,6 +118,12 @@ class Pi0FASTConfig(_model.BaseModelConfig):
|
||||
|
||||
return observation_spec, action_spec
|
||||
|
||||
def get_freeze_filter(self) -> nnx.filterlib.Filter:
|
||||
"""Returns the freeze filter based on the model config."""
|
||||
if "lora" in self.paligemma_variant:
|
||||
return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
|
||||
return nnx.Nothing
|
||||
|
||||
|
||||
class Pi0FAST(_model.BaseModel):
|
||||
def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
|
||||
|
||||
@@ -485,6 +485,23 @@ _CONFIGS = [
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
|
||||
num_train_steps=30_000,
|
||||
),
|
||||
TrainConfig(
|
||||
name="pi0_fast_libero_low_mem_finetune",
|
||||
model=pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora"),
|
||||
data=LeRobotLiberoDataConfig(
|
||||
repo_id="physical-intelligence/libero",
|
||||
base_config=DataConfig(
|
||||
local_files_only=False, # Set to True for local-only datasets.
|
||||
prompt_from_task=True,
|
||||
),
|
||||
),
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
|
||||
num_train_steps=30_000,
|
||||
freeze_filter=pi0_fast.Pi0FASTConfig(
|
||||
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
|
||||
).get_freeze_filter(),
|
||||
ema_decay=None,
|
||||
),
|
||||
#
|
||||
# Fine-tuning Aloha configs.
|
||||
#
|
||||
|
||||
Reference in New Issue
Block a user