From 007e2b91edf02332fb458c9f987dc70012c19af1 Mon Sep 17 00:00:00 2001
From: Misha Lvovsky <35537543+mishmish66@users.noreply.github.com>
Date: Fri, 7 Feb 2025 14:58:01 -0500
Subject: [PATCH] added lora fast model support (#274)

* added lora fast model support

* small config mistake

* change to ConfigDict types instead of Any as suggested in PR #274 discussion
  https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945632119

* Simplify get_freeze_filter as per comment on PR #274
  https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* actually pass the configs haha
  https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* update test to check if lora params are present
  https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* Fixed test to use nnx filters so that it is more clean

* run formatter
---
 src/openpi/models/gemma_fast.py | 86 ++++++++++++++++++---------------
 src/openpi/models/model_test.py | 22 +++++++++
 src/openpi/models/pi0_fast.py   |  7 +++
 src/openpi/training/config.py   | 17 +++++++
 4 files changed, 94 insertions(+), 38 deletions(-)

diff --git a/src/openpi/models/gemma_fast.py b/src/openpi/models/gemma_fast.py
index 7a568c6..eee39b4 100644
--- a/src/openpi/models/gemma_fast.py
+++ b/src/openpi/models/gemma_fast.py
@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
 Used for FAST autoregressive policies.
 """
 
+import dataclasses
 from typing import Literal, TypeAlias
 
 import einops
@@ -25,9 +26,10 @@ import jax
 import jax.numpy as jnp
 import ml_collections
 
+import openpi.models.lora as lora
 import openpi.shared.array_typing as at
 
-Variant = Literal["gemma_2b"]
+Variant = Literal["gemma_2b", "gemma_2b_lora"]
 
 
 def get_config(variant):
@@ -48,6 +50,26 @@ def get_config(variant):
                 "remat_policy": "nothing_saveable",
             }
         )
+    if variant == "gemma_2b_lora":
+        return ml_collections.ConfigDict(
+            {
+                "variant": variant,
+                "width": 2048,
+                "depth": 18,
+                "mlp_dim": 16_384,
+                "num_heads": 8,
+                "num_kv_heads": 1,
+                "head_dim": 256,
+                "norm_eps": 1e-6,
+                "vocab_size": 257_152,
+                "scan": True,
+                "remat_policy": "nothing_saveable",
+                "lora_configs": {
+                    "attn": lora.LoRAConfig(rank=16, alpha=16.0),
+                    "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
+                },
+            }
+        )
     raise ValueError(f"Unknown variant: {variant}")
 
 
@@ -110,21 +132,34 @@ class Attention(nn.Module):
 
     cache_dtype: str | None = None
 
+    lora_config: lora.LoRAConfig | None = None
+
     def setup(self):
         if self.num_kv_heads == self.num_heads:
-            self.qkv_einsum = Einsum(
+            self.qkv_einsum = lora.Einsum(
                 shape=(3, self.num_heads, self.features, self.head_dim),
+                name="qkv_einsum",
+                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+                lora_config=self.lora_config,
             )
         else:
-            # MQA
-            self.q_einsum = Einsum(
+            self.q_einsum = lora.Einsum(
                 shape=(self.num_heads, self.features, self.head_dim),
+                name="q_einsum",
+                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+                lora_config=self.lora_config,
             )
-            self.kv_einsum = Einsum(
+            self.kv_einsum = lora.Einsum(
                 shape=(2, self.num_kv_heads, self.features, self.head_dim),
+                name="kv_einsum",
+                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+                lora_config=self.lora_config,
             )
-        self.attn_vec_einsum = Einsum(
+        self.attn_vec_einsum = lora.Einsum(
             shape=(self.num_heads, self.head_dim, self.features),
+            name="attn_vec_einsum",
+            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+            lora_config=self.lora_config,
         )
 
     def _init_cache(self, k, v, cache_size):
@@ -189,37 +224,6 @@ class Attention(nn.Module):
 
         return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
 
 
-@at.typecheck
-class FeedForward(nn.Module):
-    """Feed forward module."""
-
-    features: int
-    hidden_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        dtype = x.dtype  # original dtype, could be half-precision
-        w_gating = self.param(
-            "gating_einsum",
-            nn.initializers.zeros_init(),
-            ((2, self.features, self.hidden_dim)),
-        ).astype(dtype)
-        ff_gate = jnp.dot(x, w_gating[0])
-        gate_value = nn.gelu(ff_gate)
-
-        ff1 = jnp.dot(x, w_gating[1])
-        activations = gate_value * ff1
-
-        w_linear = self.param(
-            "linear",
-            nn.initializers.zeros_init(),
-            (self.hidden_dim, self.features),
-        ).astype(dtype)
-        outputs = jnp.dot(activations, w_linear)
-        assert outputs.dtype == dtype
-        return outputs
-
-
 @at.typecheck
 class Block(nn.Module):
     """Transformer block."""
@@ -233,6 +237,7 @@ class Block(nn.Module):
     dropout: float = 0.0
     dropout_bdims: tuple[int, ...] = ()
     cache_dtype: str | None = None
+    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
 
     def setup(self):
         self.pre_attention_norm = RMSNorm()
@@ -242,9 +247,12 @@ class Block(nn.Module):
             features=self.embed_dim,
             head_dim=self.head_dim,
             cache_dtype=self.cache_dtype,
+            lora_config=self.lora_configs.get("attn"),
         )
         self.pre_ffw_norm = RMSNorm()
-        self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
+        self.mlp = lora.FeedForward(
+            features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
+        )
         if self.dropout:
             self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
         else:
@@ -289,6 +297,7 @@ class Module(nn.Module):
 
     scan: bool = False
     remat_policy: str = "none"
+    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
 
     @nn.compact
     def __call__(
@@ -380,6 +389,7 @@ class Module(nn.Module):
             "dropout": self.dropout,
             "dropout_bdims": self.dropout_bdims,
             "cache_dtype": self.cache_dtype,
+            "lora_configs": self.lora_configs,
         }
         layers = self.scope.push("layers")
         blocks = [
diff --git a/src/openpi/models/model_test.py b/src/openpi/models/model_test.py
index 35ef186..897b41a 100644
--- a/src/openpi/models/model_test.py
+++ b/src/openpi/models/model_test.py
@@ -1,3 +1,4 @@
+from flax import nnx
 import jax
 import pytest
 
@@ -53,6 +54,27 @@ def test_pi0_fast_model():
     assert actions.shape == (batch_size, 256)
 
 
+def test_pi0_fast_lora_model():
+    key = jax.random.key(0)
+    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
+    model = config.create(key)
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+    assert loss.shape == (batch_size,)
+
+    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+    assert actions.shape == (batch_size, 256)
+
+    lora_filter = nnx_utils.PathRegex(".*lora.*")
+    model_state = nnx.state(model)
+
+    lora_state_elems = list(model_state.filter(lora_filter))
+    assert len(lora_state_elems) > 0
+
+
 @pytest.mark.manual
 def test_model_restore():
     key = jax.random.key(0)
diff --git a/src/openpi/models/pi0_fast.py b/src/openpi/models/pi0_fast.py
index e1c57b0..fc74cf6 100644
--- a/src/openpi/models/pi0_fast.py
+++ b/src/openpi/models/pi0_fast.py
@@ -12,6 +12,7 @@ from openpi.models import model as _model
 import openpi.models.gemma_fast as _gemma
 import openpi.models.siglip as _siglip
 from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
 
 logger = logging.getLogger("openpi")
 
@@ -117,6 +118,12 @@ class Pi0FASTConfig(_model.BaseModelConfig):
 
         return observation_spec, action_spec
 
+    def get_freeze_filter(self) -> nnx.filterlib.Filter:
+        """Returns the freeze filter based on the model config."""
+        if "lora" in self.paligemma_variant:
+            return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
+        return nnx.Nothing
+
 
 class Pi0FAST(_model.BaseModel):
     def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py
index 84973e0..5b93cee 100644
--- a/src/openpi/training/config.py
+++ b/src/openpi/training/config.py
@@ -485,6 +485,23 @@ _CONFIGS = [
         weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
         num_train_steps=30_000,
     ),
+    TrainConfig(
+        name="pi0_fast_libero_low_mem_finetune",
+        model=pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora"),
+        data=LeRobotLiberoDataConfig(
+            repo_id="physical-intelligence/libero",
+            base_config=DataConfig(
+                local_files_only=False,  # Set to True for local-only datasets.
+                prompt_from_task=True,
+            ),
+        ),
+        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
+        num_train_steps=30_000,
+        freeze_filter=pi0_fast.Pi0FASTConfig(
+            action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
+        ).get_freeze_filter(),
+        ema_decay=None,
+    ),
     #
     # Fine-tuning Aloha configs.
     #
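
Usage sketch (not part of the patch): the snippet below illustrates how the freeze filter
introduced in pi0_fast.py composes with flax.nnx filters to separate the trainable LoRA
parameters from the frozen base weights, which is what the new test and the
pi0_fast_libero_low_mem_finetune config rely on. The variable names and the specific
nnx.state calls here are illustrative assumptions, not code taken from this PR.

    # Illustrative sketch only; not part of the patch.
    import jax
    from flax import nnx

    import openpi.models.pi0_fast as pi0_fast
    import openpi.shared.nnx_utils as nnx_utils

    # Build the LoRA variant of the model, the same way the new test does.
    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
    model = config.create(jax.random.key(0))

    # The freeze filter matches all LLM parameters except the injected LoRA ones,
    # so trainable state is "params that are NOT frozen".
    freeze_filter = config.get_freeze_filter()
    frozen = nnx.state(model, nnx.All(nnx.Param, freeze_filter))
    trainable = nnx.state(model, nnx.All(nnx.Param, nnx.Not(freeze_filter)))

    # LoRA parameters can also be selected directly by path, as in the updated test.
    lora_params = nnx.state(model).filter(nnx_utils.PathRegex(".*lora.*"))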