added lora fast model support (#274)

* added lora fast model support

* small config mistake

* change to ConfigDict types instead of Any as suggested in PR #274 discussion https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945632119

* Simplify get_freeze_filter as per comment on PR #274
https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* actually pass the configs haha https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* update test to check if lora params are present https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* Fixed test to use nnx filters so that it is more clean

* run formatter
This commit is contained in:
Misha Lvovsky
2025-02-07 14:58:01 -05:00
committed by GitHub
parent 2a13ed7eff
commit 007e2b91ed
4 changed files with 94 additions and 38 deletions

View File

@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
Used for FAST autoregressive policies.
"""
import dataclasses
from typing import Literal, TypeAlias
import einops
@@ -25,9 +26,10 @@ import jax
import jax.numpy as jnp
import ml_collections
import openpi.models.lora as lora
import openpi.shared.array_typing as at
Variant = Literal["gemma_2b"]
Variant = Literal["gemma_2b", "gemma_2b_lora"]
def get_config(variant):
@@ -48,6 +50,26 @@ def get_config(variant):
"remat_policy": "nothing_saveable",
}
)
if variant == "gemma_2b_lora":
return ml_collections.ConfigDict(
{
"variant": variant,
"width": 2048,
"depth": 18,
"mlp_dim": 16_384,
"num_heads": 8,
"num_kv_heads": 1,
"head_dim": 256,
"norm_eps": 1e-6,
"vocab_size": 257_152,
"scan": True,
"remat_policy": "nothing_saveable",
"lora_configs": {
"attn": lora.LoRAConfig(rank=16, alpha=16.0),
"ffn": lora.LoRAConfig(rank=16, alpha=16.0),
},
}
)
raise ValueError(f"Unknown variant: {variant}")
@@ -110,21 +132,34 @@ class Attention(nn.Module):
cache_dtype: str | None = None
lora_config: lora.LoRAConfig | None = None
def setup(self):
if self.num_kv_heads == self.num_heads:
self.qkv_einsum = Einsum(
self.qkv_einsum = lora.Einsum(
shape=(3, self.num_heads, self.features, self.head_dim),
name="qkv_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=self.lora_config,
)
else:
# MQA
self.q_einsum = Einsum(
self.q_einsum = lora.Einsum(
shape=(self.num_heads, self.features, self.head_dim),
name="q_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_config=self.lora_config,
)
self.kv_einsum = Einsum(
self.kv_einsum = lora.Einsum(
shape=(2, self.num_kv_heads, self.features, self.head_dim),
name="kv_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=self.lora_config,
)
self.attn_vec_einsum = Einsum(
self.attn_vec_einsum = lora.Einsum(
shape=(self.num_heads, self.head_dim, self.features),
name="attn_vec_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_config=self.lora_config,
)
def _init_cache(self, k, v, cache_size):
@@ -189,37 +224,6 @@ class Attention(nn.Module):
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
@at.typecheck
class FeedForward(nn.Module):
"""Feed forward module."""
features: int
hidden_dim: int
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
w_gating = self.param(
"gating_einsum",
nn.initializers.zeros_init(),
((2, self.features, self.hidden_dim)),
).astype(dtype)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
w_linear = self.param(
"linear",
nn.initializers.zeros_init(),
(self.hidden_dim, self.features),
).astype(dtype)
outputs = jnp.dot(activations, w_linear)
assert outputs.dtype == dtype
return outputs
@at.typecheck
class Block(nn.Module):
"""Transformer block."""
@@ -233,6 +237,7 @@ class Block(nn.Module):
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = ()
cache_dtype: str | None = None
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
def setup(self):
self.pre_attention_norm = RMSNorm()
@@ -242,9 +247,12 @@ class Block(nn.Module):
features=self.embed_dim,
head_dim=self.head_dim,
cache_dtype=self.cache_dtype,
lora_config=self.lora_configs.get("attn"),
)
self.pre_ffw_norm = RMSNorm()
self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
self.mlp = lora.FeedForward(
features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
)
if self.dropout:
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
else:
@@ -289,6 +297,7 @@ class Module(nn.Module):
scan: bool = False
remat_policy: str = "none"
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
@nn.compact
def __call__(
@@ -380,6 +389,7 @@ class Module(nn.Module):
"dropout": self.dropout,
"dropout_bdims": self.dropout_bdims,
"cache_dtype": self.cache_dtype,
"lora_configs": self.lora_configs,
}
layers = self.scope.push("layers")
blocks = [

View File

@@ -1,3 +1,4 @@
from flax import nnx
import jax
import pytest
@@ -53,6 +54,27 @@ def test_pi0_fast_model():
assert actions.shape == (batch_size, 256)
def test_pi0_fast_lora_model():
key = jax.random.key(0)
config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
model = config.create(key)
batch_size = 2
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
assert loss.shape == (batch_size,)
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
assert actions.shape == (batch_size, 256)
lora_filter = nnx_utils.PathRegex(".*lora.*")
model_state = nnx.state(model)
lora_state_elems = list(model_state.filter(lora_filter))
assert len(lora_state_elems) > 0
@pytest.mark.manual
def test_model_restore():
key = jax.random.key(0)

View File

@@ -12,6 +12,7 @@ from openpi.models import model as _model
import openpi.models.gemma_fast as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils
logger = logging.getLogger("openpi")
@@ -117,6 +118,12 @@ class Pi0FASTConfig(_model.BaseModelConfig):
return observation_spec, action_spec
def get_freeze_filter(self) -> nnx.filterlib.Filter:
"""Returns the freeze filter based on the model config."""
if "lora" in self.paligemma_variant:
return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
return nnx.Nothing
class Pi0FAST(_model.BaseModel):
def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):

View File

@@ -485,6 +485,23 @@ _CONFIGS = [
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
num_train_steps=30_000,
),
TrainConfig(
name="pi0_fast_libero_low_mem_finetune",
model=pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora"),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(
local_files_only=False, # Set to True for local-only datasets.
prompt_from_task=True,
),
),
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
num_train_steps=30_000,
freeze_filter=pi0_fast.Pi0FASTConfig(
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
).get_freeze_filter(),
ema_decay=None,
),
#
# Fine-tuning Aloha configs.
#