Added LoRA support for the FAST model (#274)
* Added LoRA support for the FAST model
* Fixed a small config mistake
* Changed to ConfigDict types instead of Any, as suggested in the PR #274 discussion: https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945632119
* Simplified get_freeze_filter as per the comment on PR #274: https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808
* Actually pass the configs: https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808
* Updated the test to check that LoRA params are present: https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808
* Fixed the test to use nnx filters so that it is cleaner
* Ran the formatter
This commit is contained in:
@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
|
|||||||
Used for FAST autoregressive policies.
|
Used for FAST autoregressive policies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
from typing import Literal, TypeAlias
|
from typing import Literal, TypeAlias
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
@@ -25,9 +26,10 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import ml_collections
|
import ml_collections
|
||||||
|
|
||||||
|
import openpi.models.lora as lora
|
||||||
import openpi.shared.array_typing as at
|
import openpi.shared.array_typing as at
|
||||||
|
|
||||||
Variant = Literal["gemma_2b"]
|
Variant = Literal["gemma_2b", "gemma_2b_lora"]
|
||||||
|
|
||||||
|
|
||||||
def get_config(variant):
|
def get_config(variant):
|
||||||
@@ -48,6 +50,26 @@ def get_config(variant):
|
|||||||
"remat_policy": "nothing_saveable",
|
"remat_policy": "nothing_saveable",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if variant == "gemma_2b_lora":
|
||||||
|
return ml_collections.ConfigDict(
|
||||||
|
{
|
||||||
|
"variant": variant,
|
||||||
|
"width": 2048,
|
||||||
|
"depth": 18,
|
||||||
|
"mlp_dim": 16_384,
|
||||||
|
"num_heads": 8,
|
||||||
|
"num_kv_heads": 1,
|
||||||
|
"head_dim": 256,
|
||||||
|
"norm_eps": 1e-6,
|
||||||
|
"vocab_size": 257_152,
|
||||||
|
"scan": True,
|
||||||
|
"remat_policy": "nothing_saveable",
|
||||||
|
"lora_configs": {
|
||||||
|
"attn": lora.LoRAConfig(rank=16, alpha=16.0),
|
||||||
|
"ffn": lora.LoRAConfig(rank=16, alpha=16.0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
raise ValueError(f"Unknown variant: {variant}")
|
raise ValueError(f"Unknown variant: {variant}")
|
||||||
|
|
||||||
|
|
||||||
@@ -110,21 +132,34 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
cache_dtype: str | None = None
|
cache_dtype: str | None = None
|
||||||
|
|
||||||
|
lora_config: lora.LoRAConfig | None = None
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
if self.num_kv_heads == self.num_heads:
|
if self.num_kv_heads == self.num_heads:
|
||||||
self.qkv_einsum = Einsum(
|
self.qkv_einsum = lora.Einsum(
|
||||||
shape=(3, self.num_heads, self.features, self.head_dim),
|
shape=(3, self.num_heads, self.features, self.head_dim),
|
||||||
|
name="qkv_einsum",
|
||||||
|
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||||
|
lora_config=self.lora_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# MQA
|
self.q_einsum = lora.Einsum(
|
||||||
self.q_einsum = Einsum(
|
|
||||||
shape=(self.num_heads, self.features, self.head_dim),
|
shape=(self.num_heads, self.features, self.head_dim),
|
||||||
|
name="q_einsum",
|
||||||
|
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||||
|
lora_config=self.lora_config,
|
||||||
)
|
)
|
||||||
self.kv_einsum = Einsum(
|
self.kv_einsum = lora.Einsum(
|
||||||
shape=(2, self.num_kv_heads, self.features, self.head_dim),
|
shape=(2, self.num_kv_heads, self.features, self.head_dim),
|
||||||
|
name="kv_einsum",
|
||||||
|
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||||
|
lora_config=self.lora_config,
|
||||||
)
|
)
|
||||||
self.attn_vec_einsum = Einsum(
|
self.attn_vec_einsum = lora.Einsum(
|
||||||
shape=(self.num_heads, self.head_dim, self.features),
|
shape=(self.num_heads, self.head_dim, self.features),
|
||||||
|
name="attn_vec_einsum",
|
||||||
|
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||||
|
lora_config=self.lora_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_cache(self, k, v, cache_size):
|
def _init_cache(self, k, v, cache_size):
|
||||||
@@ -189,37 +224,6 @@ class Attention(nn.Module):
|
|||||||
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
|
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
|
||||||
|
|
||||||
|
|
||||||
@at.typecheck
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
"""Feed forward module."""
|
|
||||||
|
|
||||||
features: int
|
|
||||||
hidden_dim: int
|
|
||||||
|
|
||||||
@nn.compact
|
|
||||||
def __call__(self, x):
|
|
||||||
dtype = x.dtype # original dtype, could be half-precision
|
|
||||||
w_gating = self.param(
|
|
||||||
"gating_einsum",
|
|
||||||
nn.initializers.zeros_init(),
|
|
||||||
((2, self.features, self.hidden_dim)),
|
|
||||||
).astype(dtype)
|
|
||||||
ff_gate = jnp.dot(x, w_gating[0])
|
|
||||||
gate_value = nn.gelu(ff_gate)
|
|
||||||
|
|
||||||
ff1 = jnp.dot(x, w_gating[1])
|
|
||||||
activations = gate_value * ff1
|
|
||||||
|
|
||||||
w_linear = self.param(
|
|
||||||
"linear",
|
|
||||||
nn.initializers.zeros_init(),
|
|
||||||
(self.hidden_dim, self.features),
|
|
||||||
).astype(dtype)
|
|
||||||
outputs = jnp.dot(activations, w_linear)
|
|
||||||
assert outputs.dtype == dtype
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
@at.typecheck
|
@at.typecheck
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
"""Transformer block."""
|
"""Transformer block."""
|
||||||
@@ -233,6 +237,7 @@ class Block(nn.Module):
|
|||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
dropout_bdims: tuple[int, ...] = ()
|
dropout_bdims: tuple[int, ...] = ()
|
||||||
cache_dtype: str | None = None
|
cache_dtype: str | None = None
|
||||||
|
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.pre_attention_norm = RMSNorm()
|
self.pre_attention_norm = RMSNorm()
|
||||||
@@ -242,9 +247,12 @@ class Block(nn.Module):
|
|||||||
features=self.embed_dim,
|
features=self.embed_dim,
|
||||||
head_dim=self.head_dim,
|
head_dim=self.head_dim,
|
||||||
cache_dtype=self.cache_dtype,
|
cache_dtype=self.cache_dtype,
|
||||||
|
lora_config=self.lora_configs.get("attn"),
|
||||||
)
|
)
|
||||||
self.pre_ffw_norm = RMSNorm()
|
self.pre_ffw_norm = RMSNorm()
|
||||||
self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
|
self.mlp = lora.FeedForward(
|
||||||
|
features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
|
||||||
|
)
|
||||||
if self.dropout:
|
if self.dropout:
|
||||||
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
|
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
|
||||||
else:
|
else:
|
||||||
@@ -289,6 +297,7 @@ class Module(nn.Module):
|
|||||||
|
|
||||||
scan: bool = False
|
scan: bool = False
|
||||||
remat_policy: str = "none"
|
remat_policy: str = "none"
|
||||||
|
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -380,6 +389,7 @@ class Module(nn.Module):
|
|||||||
"dropout": self.dropout,
|
"dropout": self.dropout,
|
||||||
"dropout_bdims": self.dropout_bdims,
|
"dropout_bdims": self.dropout_bdims,
|
||||||
"cache_dtype": self.cache_dtype,
|
"cache_dtype": self.cache_dtype,
|
||||||
|
"lora_configs": self.lora_configs,
|
||||||
}
|
}
|
||||||
layers = self.scope.push("layers")
|
layers = self.scope.push("layers")
|
||||||
blocks = [
|
blocks = [
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from flax import nnx
|
||||||
import jax
|
import jax
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -53,6 +54,27 @@ def test_pi0_fast_model():
|
|||||||
assert actions.shape == (batch_size, 256)
|
assert actions.shape == (batch_size, 256)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pi0_fast_lora_model():
|
||||||
|
key = jax.random.key(0)
|
||||||
|
config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
|
||||||
|
model = config.create(key)
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||||
|
|
||||||
|
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||||
|
assert loss.shape == (batch_size,)
|
||||||
|
|
||||||
|
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
|
||||||
|
assert actions.shape == (batch_size, 256)
|
||||||
|
|
||||||
|
lora_filter = nnx_utils.PathRegex(".*lora.*")
|
||||||
|
model_state = nnx.state(model)
|
||||||
|
|
||||||
|
lora_state_elems = list(model_state.filter(lora_filter))
|
||||||
|
assert len(lora_state_elems) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.manual
|
@pytest.mark.manual
|
||||||
def test_model_restore():
|
def test_model_restore():
|
||||||
key = jax.random.key(0)
|
key = jax.random.key(0)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from openpi.models import model as _model
|
|||||||
import openpi.models.gemma_fast as _gemma
|
import openpi.models.gemma_fast as _gemma
|
||||||
import openpi.models.siglip as _siglip
|
import openpi.models.siglip as _siglip
|
||||||
from openpi.shared import array_typing as at
|
from openpi.shared import array_typing as at
|
||||||
|
import openpi.shared.nnx_utils as nnx_utils
|
||||||
|
|
||||||
logger = logging.getLogger("openpi")
|
logger = logging.getLogger("openpi")
|
||||||
|
|
||||||
@@ -117,6 +118,12 @@ class Pi0FASTConfig(_model.BaseModelConfig):
|
|||||||
|
|
||||||
return observation_spec, action_spec
|
return observation_spec, action_spec
|
||||||
|
|
||||||
|
def get_freeze_filter(self) -> nnx.filterlib.Filter:
|
||||||
|
"""Returns the freeze filter based on the model config."""
|
||||||
|
if "lora" in self.paligemma_variant:
|
||||||
|
return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
|
||||||
|
return nnx.Nothing
|
||||||
|
|
||||||
|
|
||||||
class Pi0FAST(_model.BaseModel):
|
class Pi0FAST(_model.BaseModel):
|
||||||
def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
|
def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
|
||||||
|
|||||||
@@ -485,6 +485,23 @@ _CONFIGS = [
|
|||||||
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
|
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
|
||||||
num_train_steps=30_000,
|
num_train_steps=30_000,
|
||||||
),
|
),
|
||||||
|
TrainConfig(
|
||||||
|
name="pi0_fast_libero_low_mem_finetune",
|
||||||
|
model=pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora"),
|
||||||
|
data=LeRobotLiberoDataConfig(
|
||||||
|
repo_id="physical-intelligence/libero",
|
||||||
|
base_config=DataConfig(
|
||||||
|
local_files_only=False, # Set to True for local-only datasets.
|
||||||
|
prompt_from_task=True,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
|
||||||
|
num_train_steps=30_000,
|
||||||
|
freeze_filter=pi0_fast.Pi0FASTConfig(
|
||||||
|
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
|
||||||
|
).get_freeze_filter(),
|
||||||
|
ema_decay=None,
|
||||||
|
),
|
||||||
#
|
#
|
||||||
# Fine-tuning Aloha configs.
|
# Fine-tuning Aloha configs.
|
||||||
#
|
#
|
||||||
|
|||||||
Reference in New Issue
Block a user