Added LoRA support for the FAST model (#274)

* added LoRA support for the FAST model

* fixed a small config mistake

* changed to ConfigDict types instead of Any, as suggested in the PR #274 discussion https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945632119

* Simplified get_freeze_filter as per the comment on PR #274:
https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* actually pass the configs haha https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* updated test to check that LoRA params are present https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* Fixed test to use nnx filters so that it is cleaner

* run formatter
Misha Lvovsky
2025-02-07 14:58:01 -05:00
committed by GitHub
parent 2a13ed7eff
commit 007e2b91ed
4 changed files with 94 additions and 38 deletions
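For context: LoRA (Low-Rank Adaptation) freezes a pretrained weight matrix and trains only a small low-rank correction on top of it, which is what makes this a low-memory fine-tuning path. A minimal generic sketch of the idea in JAX follows — illustrative only, not openpi's actual lora module:

import jax
import jax.numpy as jnp

# Minimal LoRA sketch: y = x @ W + (alpha / rank) * x @ A @ B, where W stays
# frozen and only the small factors A and B are trained.
def lora_linear(x, w, a, b, alpha=16.0):
    rank = a.shape[-1]
    return x @ w + (alpha / rank) * ((x @ a) @ b)

key = jax.random.key(0)
x = jax.random.normal(key, (2, 2048))    # batch of activations
w = jax.random.normal(key, (2048, 256))  # frozen pretrained weight
a = jax.random.normal(key, (2048, 16))   # trainable factor A (rank 16)
b = jnp.zeros((16, 256))                 # trainable factor B, zero-init so the
                                         # adapter starts as an exact no-op
y = lora_linear(x, w, a, b)
assert y.shape == (2, 256)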

src/openpi/models/gemma_fast.py

@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
 Used for FAST autoregressive policies.
 """
 
+import dataclasses
 from typing import Literal, TypeAlias
 
 import einops
@@ -25,9 +26,10 @@ import jax
 import jax.numpy as jnp
 import ml_collections
 
+import openpi.models.lora as lora
 import openpi.shared.array_typing as at
 
-Variant = Literal["gemma_2b"]
+Variant = Literal["gemma_2b", "gemma_2b_lora"]
 
 
 def get_config(variant):
@@ -48,6 +50,26 @@ def get_config(variant):
                 "remat_policy": "nothing_saveable",
             }
         )
+    if variant == "gemma_2b_lora":
+        return ml_collections.ConfigDict(
+            {
+                "variant": variant,
+                "width": 2048,
+                "depth": 18,
+                "mlp_dim": 16_384,
+                "num_heads": 8,
+                "num_kv_heads": 1,
+                "head_dim": 256,
+                "norm_eps": 1e-6,
+                "vocab_size": 257_152,
+                "scan": True,
+                "remat_policy": "nothing_saveable",
+                "lora_configs": {
+                    "attn": lora.LoRAConfig(rank=16, alpha=16.0),
+                    "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
+                },
+            }
+        )
 
     raise ValueError(f"Unknown variant: {variant}")
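Hedged usage sketch for the new variant, assuming the module import path openpi.models.gemma_fast (which matches the import added in pi0_fast.py below) and that LoRAConfig exposes the rank/alpha fields passed above:

import openpi.models.gemma_fast as gemma_fast

config = gemma_fast.get_config("gemma_2b_lora")
# The model dims match plain gemma_2b; only lora_configs is new.
assert config.lora_configs["attn"].rank == 16
assert config.lora_configs["ffn"].alpha == 16.0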
@@ -110,21 +132,34 @@ class Attention(nn.Module):
     cache_dtype: str | None = None
+    lora_config: lora.LoRAConfig | None = None
 
     def setup(self):
         if self.num_kv_heads == self.num_heads:
-            self.qkv_einsum = Einsum(
+            self.qkv_einsum = lora.Einsum(
                 shape=(3, self.num_heads, self.features, self.head_dim),
+                name="qkv_einsum",
+                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+                lora_config=self.lora_config,
             )
         else:
-            # MQA
-            self.q_einsum = Einsum(
+            self.q_einsum = lora.Einsum(
                 shape=(self.num_heads, self.features, self.head_dim),
+                name="q_einsum",
+                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+                lora_config=self.lora_config,
             )
-            self.kv_einsum = Einsum(
+            self.kv_einsum = lora.Einsum(
                 shape=(2, self.num_kv_heads, self.features, self.head_dim),
+                name="kv_einsum",
+                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+                lora_config=self.lora_config,
             )
-        self.attn_vec_einsum = Einsum(
+        self.attn_vec_einsum = lora.Einsum(
             shape=(self.num_heads, self.head_dim, self.features),
+            name="attn_vec_einsum",
+            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+            lora_config=self.lora_config,
         )
 
     def _init_cache(self, k, v, cache_size):
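The lora.Einsum wrapper itself lives in openpi.models.lora and is not part of this diff. As a hedged sketch of the standard LoRA formulation such a wrapper implements — the base einsum plus a scaled low-rank correction; details of the real class may differ:

import flax.linen as nn
import jax.numpy as jnp

# Hedged sketch only -- the real lora.Einsum is in openpi.models.lora.
class LoraEinsumSketch(nn.Module):
    shape: tuple[int, ...]  # e.g. (num_heads, features, head_dim)
    rank: int = 16
    alpha: float = 16.0

    @nn.compact
    def __call__(self, eqn: str, x):
        w = self.param("w", nn.initializers.lecun_normal(in_axis=-2, out_axis=-1), self.shape)
        # Factor the update over the last two axes of w; zero-init lora_b so
        # training starts from exactly the pretrained behavior.
        a = self.param("lora_a", nn.initializers.normal(), (*self.shape[:-1], self.rank))
        b = self.param("lora_b", nn.initializers.zeros_init(), (*self.shape[:-2], self.rank, self.shape[-1]))
        delta = (self.alpha / self.rank) * jnp.einsum("...ir,...rd->...id", a, b)
        return jnp.einsum(eqn, x, w + delta)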
@@ -189,37 +224,6 @@ class Attention(nn.Module):
         return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
 
 
-@at.typecheck
-class FeedForward(nn.Module):
-    """Feed forward module."""
-
-    features: int
-    hidden_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        dtype = x.dtype  # original dtype, could be half-precision
-        w_gating = self.param(
-            "gating_einsum",
-            nn.initializers.zeros_init(),
-            ((2, self.features, self.hidden_dim)),
-        ).astype(dtype)
-        ff_gate = jnp.dot(x, w_gating[0])
-        gate_value = nn.gelu(ff_gate)
-        ff1 = jnp.dot(x, w_gating[1])
-        activations = gate_value * ff1
-        w_linear = self.param(
-            "linear",
-            nn.initializers.zeros_init(),
-            (self.hidden_dim, self.features),
-        ).astype(dtype)
-        outputs = jnp.dot(activations, w_linear)
-        assert outputs.dtype == dtype
-        return outputs
-
-
 @at.typecheck
 class Block(nn.Module):
     """Transformer block."""
@@ -233,6 +237,7 @@ class Block(nn.Module):
     dropout: float = 0.0
     dropout_bdims: tuple[int, ...] = ()
     cache_dtype: str | None = None
+    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
 
     def setup(self):
         self.pre_attention_norm = RMSNorm()
@@ -242,9 +247,12 @@ class Block(nn.Module):
             features=self.embed_dim,
             head_dim=self.head_dim,
             cache_dtype=self.cache_dtype,
+            lora_config=self.lora_configs.get("attn"),
         )
         self.pre_ffw_norm = RMSNorm()
-        self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
+        self.mlp = lora.FeedForward(
+            features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
+        )
         if self.dropout:
             self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
         else:
@@ -289,6 +297,7 @@ class Module(nn.Module):
     scan: bool = False
     remat_policy: str = "none"
+    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
 
     @nn.compact
     def __call__(
@@ -380,6 +389,7 @@ class Module(nn.Module):
             "dropout": self.dropout,
             "dropout_bdims": self.dropout_bdims,
             "cache_dtype": self.cache_dtype,
+            "lora_configs": self.lora_configs,
         }
         layers = self.scope.push("layers")
         blocks = [

src/openpi/models/pi0_fast_test.py

@@ -1,3 +1,4 @@
+from flax import nnx
 import jax
 import pytest
 
@@ -53,6 +54,27 @@ def test_pi0_fast_model():
     assert actions.shape == (batch_size, 256)
 
 
+def test_pi0_fast_lora_model():
+    key = jax.random.key(0)
+    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
+    model = config.create(key)
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+    assert loss.shape == (batch_size,)
+
+    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+    assert actions.shape == (batch_size, 256)
+
+    lora_filter = nnx_utils.PathRegex(".*lora.*")
+    model_state = nnx.state(model)
+    lora_state_elems = list(model_state.filter(lora_filter))
+    assert len(lora_state_elems) > 0
+
+
 @pytest.mark.manual
 def test_model_restore():
     key = jax.random.key(0)
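nnx_utils.PathRegex comes from openpi.shared.nnx_utils. Conceptually it is an nnx predicate filter that matches the flattened parameter path against a regex; a hedged re-sketch (the real implementation may differ in details):

import re

# Hedged sketch of a path-regex filter for flax.nnx; openpi's real PathRegex
# lives in openpi.shared.nnx_utils.
def path_regex(pattern: str):
    compiled = re.compile(pattern)

    def predicate(path, value) -> bool:
        # nnx state paths are tuples of keys; joining them lets one regex
        # match across nesting levels, e.g. "layers/3/attn/qkv_einsum/lora_a".
        return compiled.fullmatch("/".join(map(str, path))) is not None

    return predicate

# Usage mirroring the test above:
#   lora_state = nnx.state(model).filter(path_regex(".*lora.*"))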

src/openpi/models/pi0_fast.py

@@ -12,6 +12,7 @@ from openpi.models import model as _model
 import openpi.models.gemma_fast as _gemma
 import openpi.models.siglip as _siglip
 from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
 
 logger = logging.getLogger("openpi")
 
@@ -117,6 +118,12 @@ class Pi0FASTConfig(_model.BaseModelConfig):
         return observation_spec, action_spec
 
+    def get_freeze_filter(self) -> nnx.filterlib.Filter:
+        """Returns the freeze filter based on the model config."""
+        if "lora" in self.paligemma_variant:
+            return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
+        return nnx.Nothing
+
 
 class Pi0FAST(_model.BaseModel):
     def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):

src/openpi/training/config.py

@@ -485,6 +485,23 @@ _CONFIGS = [
         weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
         num_train_steps=30_000,
     ),
+    TrainConfig(
+        name="pi0_fast_libero_low_mem_finetune",
+        model=pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora"),
+        data=LeRobotLiberoDataConfig(
+            repo_id="physical-intelligence/libero",
+            base_config=DataConfig(
+                local_files_only=False,  # Set to True for local-only datasets.
+                prompt_from_task=True,
+            ),
+        ),
+        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
+        num_train_steps=30_000,
+        freeze_filter=pi0_fast.Pi0FASTConfig(
+            action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
+        ).get_freeze_filter(),
+        ema_decay=None,
+    ),
     #
     # Fine-tuning Aloha configs.
     #
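With this entry registered in _CONFIGS, the run can presumably be launched by name through openpi's standard training entry point, e.g. `uv run scripts/train.py pi0_fast_libero_low_mem_finetune --exp-name=<your_exp>` (assuming the repo's usual scripts/train.py launcher). Note ema_decay=None: keeping an EMA copy would hold a second full set of weights in memory, which would presumably undercut the low-memory goal of this config.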