From be46bdea8fc9d6ef720bed7e00c191feae9ca34b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 23 Oct 2025 13:50:30 +0200 Subject: [PATCH] feat(policies): add Nvidia Gr00t N1.5 model (#2292) * feat(policies): add Nvidia Gr00t N1.5 model Co-authored-by: lbenhorin Co-authored-by: Aravindh Co-authored-by: nv-sachdevkartik Co-authored-by: youliangt Co-authored-by: Michel Aractingi Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Jade Choghari * fix(docs): add groot to index Co-authored-by: sachdevkartik --------- Co-authored-by: lbenhorin Co-authored-by: Aravindh Co-authored-by: nv-sachdevkartik Co-authored-by: youliangt Co-authored-by: Michel Aractingi Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Jade Choghari Co-authored-by: sachdevkartik --- .github/workflows/full_tests.yml | 2 +- docs/source/_toctree.yml | 2 + docs/source/groot.mdx | 122 ++++ docs/source/policy_groot_README.md | 27 + pyproject.toml | 13 + src/lerobot/policies/__init__.py | 2 + src/lerobot/policies/factory.py | 36 + src/lerobot/policies/groot/README.md | 1 + src/lerobot/policies/groot/__init__.py | 21 + .../policies/groot/action_head/__init__.py | 14 + .../groot/action_head/action_encoder.py | 54 ++ .../groot/action_head/cross_attention_dit.py | 370 ++++++++++ .../action_head/flow_matching_action_head.py | 406 +++++++++++ .../policies/groot/configuration_groot.py | 201 ++++++ .../configuration_eagle2_5_vl.py | 135 ++++ .../image_processing_eagle2_5_vl_fast.py | 504 +++++++++++++ .../eagle2_hg_model/modeling_eagle2_5_vl.py | 395 +++++++++++ .../eagle2_hg_model/processing_eagle2_5_vl.py | 518 ++++++++++++++ src/lerobot/policies/groot/groot_n1.py | 376 ++++++++++ src/lerobot/policies/groot/modeling_groot.py | 198 ++++++ src/lerobot/policies/groot/processor_groot.py | 664 ++++++++++++++++++ src/lerobot/policies/groot/utils.py | 47 ++ src/lerobot/scripts/lerobot_eval.py | 13 +- src/lerobot/utils/import_utils.py | 1 + tests/policies/groot/test_groot_lerobot.py | 207 ++++++ .../policies/groot/test_groot_vs_original.py | 443 ++++++++++++ 26 files changed, 4766 insertions(+), 6 deletions(-) create mode 100644 docs/source/groot.mdx create mode 100644 docs/source/policy_groot_README.md create mode 120000 src/lerobot/policies/groot/README.md create mode 100644 src/lerobot/policies/groot/__init__.py create mode 100644 src/lerobot/policies/groot/action_head/__init__.py create mode 100644 src/lerobot/policies/groot/action_head/action_encoder.py create mode 100755 src/lerobot/policies/groot/action_head/cross_attention_dit.py create mode 100644 src/lerobot/policies/groot/action_head/flow_matching_action_head.py create mode 100644 src/lerobot/policies/groot/configuration_groot.py create mode 100755 src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py create mode 100644 src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py create mode 100755 src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py create mode 100755 src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py create mode 100644 src/lerobot/policies/groot/groot_n1.py create mode 100644 src/lerobot/policies/groot/modeling_groot.py create mode 100644 src/lerobot/policies/groot/processor_groot.py create mode 100644 src/lerobot/policies/groot/utils.py create mode 100644 tests/policies/groot/test_groot_lerobot.py create mode 100644 tests/policies/groot/test_groot_vs_original.py diff --git a/.github/workflows/full_tests.yml 
b/.github/workflows/full_tests.yml index d16fe5e7..0155eec1 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -78,7 +78,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install lerobot with all extras - run: uv sync --all-extras + run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional - name: Run pytest (all extras) run: uv run pytest tests -vv --maxfail=10 diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 5e100013..16fdb5e7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -37,6 +37,8 @@ title: π₀ (Pi0) - local: pi05 title: π₀.₅ (Pi05) + - local: groot + title: Nvidia Gr00t N1.5 title: "Policies" - sections: - local: il_sim diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx new file mode 100644 index 00000000..07348bea --- /dev/null +++ b/docs/source/groot.mdx @@ -0,0 +1,122 @@ +# Gr00t N1.5 Policy +
+Gr00t N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments. +
+This document outlines the specifics of its integration and usage within the LeRobot framework. +
+## Model Overview +
+NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots. +
+Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks. +
+GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception. +
+Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes: +
+- Real captured data from robots. +- Synthetic data generated using NVIDIA Isaac GR00T Blueprint. +- Internet-scale video data. +
+This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments. +
+## Installation Requirements +
+As of today, Gr00t N1.5 requires Flash Attention for its internal workings. +
+We are working on making this optional, but in the meantime an extra installation step is required and the policy can only be used on CUDA-enabled devices. +
+1. Follow the Environment Setup section of our [Installation Guide](./installation). **Attention:** don't install `lerobot` in this step. +2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running: +
+```bash +# Check https://pytorch.org/get-started/locally/ for your system +pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX +pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies +pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation +python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')" +``` +
+3.
Install LeRobot by running: +
+```bash +pip install lerobot[groot] # consider also installing the libero, dev, and test extras +``` +
+## Usage +
+To use Gr00t in your LeRobot configuration, specify the policy type as: +
+```python +policy.type=groot +``` +
+## Training +
+### Training Command Example +
+Here's a complete training command for finetuning the base Gr00t model on your own dataset: +
+```bash +# Using a multi-GPU setup +accelerate launch \ + --multi_gpu \ + --num_processes=$NUM_GPUS \ + $(which lerobot-train) \ + --output_dir=$OUTPUT_DIR \ + --save_checkpoint=true \ + --batch_size=$BATCH_SIZE \ + --steps=$NUM_STEPS \ + --save_freq=$SAVE_FREQ \ + --log_freq=$LOG_FREQ \ + --policy.push_to_hub=true \ + --policy.type=groot \ + --policy.repo_id=$REPO_ID \ + --policy.tune_diffusion_model=false \ + --dataset.repo_id=$DATASET_ID \ + --wandb.enable=true \ + --wandb.disable_artifact=true \ + --job_name=$JOB_NAME +``` +
+## Performance Results +
+### Libero Benchmark Results +
+Gr00t has demonstrated strong performance on the Libero benchmark suite. To validate the LeRobot implementation, we finetuned the Gr00t N1.5 model for 20k steps on the Libero dataset and compared the results against the Gr00t reference. +
+| Benchmark | LeRobot Implementation | Gr00t Reference | +| ------------------ | ---------------------- | --------------- | +| **Libero Spatial** | 82% | 92.0% | +| **Libero Object** | 99% | 92.0% | +| **Libero Long** | 72% | 76.0% | +| **Average** | 84% | 87.0% | +
+These results demonstrate Gr00t's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section. +
+### Evaluate on your hardware setup +
+Once you have trained your model, you can run inference on your downstream hardware task by following our instructions in [Imitation Learning for Robots](./il_robots). For example: +
+```bash +lerobot-record \ + --robot.type=bi_so100_follower \ + --robot.left_arm_port=/dev/ttyACM1 \ + --robot.right_arm_port=/dev/ttyACM0 \ + --robot.id=bimanual_follower \ + --robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}, + left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30}, + top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30}, + }' \ + --display_data=true \ + --dataset.repo_id=/eval_groot-bimanual \ + --dataset.num_episodes=10 \ + --dataset.single_task="Grab and handover the red cube to the other arm" \ + --dataset.episode_time_s=30 \ + --dataset.reset_time_s=10 \ + --policy.path=/groot-bimanual # your trained model +``` +
+## License +
+This model follows the **Apache 2.0 License**, consistent with the original [Gr00t repository](https://github.com/NVIDIA/Isaac-GR00T).
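
## Programmatic Usage

Besides the CLI flags shown above, the Gr00t policy is also reachable through LeRobot's policy factory. The snippet below is a minimal sketch based on the `get_policy_class`/`make_policy_config` hooks added in `factory.py` in this patch; the commented checkpoint path is a placeholder, and loading it assumes the standard LeRobot `from_pretrained`/`select_action` policy interface.

```python
from lerobot.policies.factory import get_policy_class, make_policy_config

# Build a Gr00t config and resolve the policy class, mirroring what
# `--policy.type=groot` does on the command line.
config = make_policy_config("groot", base_model_path="nvidia/GR00T-N1.5-3B")
policy_cls = get_policy_class("groot")  # -> GrootPolicy

print(config.type, policy_cls.__name__)

# To load a finetuned checkpoint and run inference, the usual LeRobot policy
# API applies (the checkpoint path below is a placeholder):
# policy = policy_cls.from_pretrained("<hf_user>/groot-bimanual")
# action = policy.select_action(preprocessed_observation)
```

In normal training and evaluation you do not need to call these directly; `lerobot-train` and `lerobot-record` resolve the policy from `--policy.type=groot` or `--policy.path` for you.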
diff --git a/docs/source/policy_groot_README.md b/docs/source/policy_groot_README.md new file mode 100644 index 00000000..efcd76eb --- /dev/null +++ b/docs/source/policy_groot_README.md @@ -0,0 +1,27 @@ +## Research Paper +
+Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/ +
+## Repository +
+Code: https://github.com/NVIDIA/Isaac-GR00T +
+## Citation +
+```bibtex +@inproceedings{gr00tn1_2025, + archivePrefix = {arxiv}, + eprint = {2503.14734}, + title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots}, + author = {NVIDIA and Johan Bjorck and Fernando Castañeda and Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu}, + month = {March}, + year = {2025}, + booktitle = {ArXiv Preprint}, +} +``` +
+## Additional Resources +
+Blog: https://developer.nvidia.com/isaac/gr00t +
+Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B diff --git a/pyproject.toml b/pyproject.toml index 1c71acec..b76593b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,17 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"] # Policies pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] +groot = [ + "lerobot[transformers-dep]", + "peft>=0.13.0,<1.0.0", + "dm-tree>=0.1.8,<1.0.0", + "timm>=1.0.0,<1.1.0", + "safetensors>=0.4.3,<1.0.0", + "Pillow>=10.0.0,<13.0.0", + "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')", + "ninja>=1.11.1,<2.0.0", + "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" +] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features @@ -145,6 +156,7 @@ all = [ "lerobot[intelrealsense]", "lerobot[pi]", "lerobot[smolvla]", + # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn "lerobot[hilserl]", "lerobot[async]", "lerobot[dev]", @@ -243,6 +255,7 @@ default.extend-ignore-identifiers-re = [ "pn", "ser", "ein", + "thw", "inpt", ] diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 49f1e0f9..4cdc89ea 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -14,6 +14,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig +from .groot.configuration_groot import GrootConfig as GrootConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi05.configuration_pi05 import PI05Config as PI05Config from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig @@ -29,4 +30,5 @@ __all__ = [ "SmolVLAConfig", "TDMPCConfig", "VQBeTConfig", + "GrootConfig", ] diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 6e524f2a..bdad5cbb 100644 --- a/src/lerobot/policies/factory.py +++
b/src/lerobot/policies/factory.py @@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy @@ -101,6 +102,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy + elif name == "groot": + from lerobot.policies.groot.modeling_groot import GrootPolicy + + return GrootPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -142,6 +147,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SmolVLAConfig(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) + elif policy_type == "groot": + return GrootConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") @@ -199,6 +206,27 @@ def make_pre_post_processors( policy configuration type. """ if pretrained_path: + # TODO(Steven): Temporary patch, implement correctly the processors for Gr00t + if isinstance(policy_cfg, GrootConfig): + # GROOT handles normalization in groot_pack_inputs_v3 step + # Need to override both stats AND normalize_min_max since saved config might be empty + preprocessor_overrides = {} + postprocessor_overrides = {} + preprocessor_overrides["groot_pack_inputs_v3"] = { + "stats": kwargs.get("dataset_stats"), + "normalize_min_max": True, + } + + # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats + env_action_dim = policy_cfg.output_features["action"].shape[0] + postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = { + "stats": kwargs.get("dataset_stats"), + "normalize_min_max": True, + "env_action_dim": env_action_dim, + } + kwargs["preprocessor_overrides"] = preprocessor_overrides + kwargs["postprocessor_overrides"] = postprocessor_overrides + return ( PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, @@ -293,6 +321,14 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, GrootConfig): + from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors + + processors = make_groot_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + else: raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") diff --git a/src/lerobot/policies/groot/README.md b/src/lerobot/policies/groot/README.md new file mode 120000 index 00000000..ff4937f5 --- /dev/null +++ b/src/lerobot/policies/groot/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_groot_README.md \ No newline at end of file diff --git a/src/lerobot/policies/groot/__init__.py b/src/lerobot/policies/groot/__init__.py new file mode 100644 index 00000000..c8933ff5 --- /dev/null +++ b/src/lerobot/policies/groot/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Nvidia and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_groot import GrootConfig +from .modeling_groot import GrootPolicy +from .processor_groot import make_groot_pre_post_processors + +__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"] diff --git a/src/lerobot/policies/groot/action_head/__init__.py b/src/lerobot/policies/groot/action_head/__init__.py new file mode 100644 index 00000000..3159bfe6 --- /dev/null +++ b/src/lerobot/policies/groot/action_head/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/policies/groot/action_head/action_encoder.py b/src/lerobot/policies/groot/action_head/action_encoder.py new file mode 100644 index 00000000..c6fa0a77 --- /dev/null +++ b/src/lerobot/policies/groot/action_head/action_encoder.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +def swish(x): + return x * torch.sigmoid(x) + + +class SinusoidalPositionalEncoding(nn.Module): + """ + Produces a sinusoidal encoding of shape (B, T, w) + given timesteps of shape (B, T). 
+ """ + + def __init__(self, embedding_dim): + super().__init__() + self.embedding_dim = embedding_dim + + def forward(self, timesteps): + # timesteps: shape (B, T) + # We'll compute sin/cos frequencies across dim T + timesteps = timesteps.float() # ensure float + + b, t = timesteps.shape + device = timesteps.device + + half_dim = self.embedding_dim // 2 + # typical log space frequencies for sinusoidal encoding + exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * ( + torch.log(torch.tensor(10000.0)) / half_dim + ) + # Expand timesteps to (B, T, 1) then multiply + freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim) + + sin = torch.sin(freqs) + cos = torch.cos(freqs) + enc = torch.cat([sin, cos], dim=-1) # (B, T, w) + + return enc diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py new file mode 100755 index 00000000..40f7ba60 --- /dev/null +++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py @@ -0,0 +1,370 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn.functional as F # noqa: N812 +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.embeddings import ( + SinusoidalPositionalEmbedding, + TimestepEmbedding, + Timesteps, +) +from torch import nn + + +class TimestepEncoder(nn.Module): + def __init__(self, embedding_dim, compute_dtype=torch.float32): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps): + dtype = next(self.parameters()).dtype + timesteps_proj = self.time_proj(timesteps).to(dtype) + timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) + return timesteps_emb + + +class AdaLayerNorm(nn.Module): + def __init__( + self, + embedding_dim: int, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + self.chunk_dim = chunk_dim + output_dim = embedding_dim * 2 + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + temb = self.linear(self.silu(temb)) + scale, shift = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale[:, None]) + shift[:, None] + return x + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + 
attention_bias: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: str | None = None, + num_positional_embeddings: int | None = None, + ff_inner_dim: int | None = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.norm_type = norm_type + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + if final_dropout: + self.final_dropout = nn.Dropout(dropout) + else: + self.final_dropout = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + temb: torch.LongTensor | None = None, + ) -> torch.Tensor: + # 0. Self-Attention + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, temb) + else: + norm_hidden_states = self.norm1(hidden_states) + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + # encoder_attention_mask=encoder_attention_mask, + ) + if self.final_dropout: + attn_output = self.final_dropout(attn_output) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 4. 
Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + return hidden_states + + +class DiT(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 8, + attention_head_dim: int = 64, + output_dim: int = 26, + num_layers: int = 12, + dropout: float = 0.1, + attention_bias: bool = True, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: int | None = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + max_num_positional_embeddings: int = 512, + compute_dtype=torch.float32, + final_dropout: bool = True, + positional_embeddings: str | None = "sinusoidal", + interleave_self_attention=False, + cross_attention_dim: int | None = None, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.gradient_checkpointing = False + + # Timestep encoder + self.timestep_encoder = TimestepEncoder( + embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype + ) + + all_blocks = [] + for idx in range(self.config.num_layers): + use_self_attn = idx % 2 == 1 and interleave_self_attention + curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None + + all_blocks += [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + activation_fn=self.config.activation_fn, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + positional_embeddings=positional_embeddings, + num_positional_embeddings=self.config.max_num_positional_embeddings, + final_dropout=final_dropout, + cross_attention_dim=curr_cross_attention_dim, + ) + ] + self.transformer_blocks = nn.ModuleList(all_blocks) + + # Output blocks + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim) + print( + "Total number of DiT parameters: ", + sum(p.numel() for p in self.parameters() if p.requires_grad), + ) + + def forward( + self, + hidden_states: torch.Tensor, # Shape: (B, T, D) + encoder_hidden_states: torch.Tensor, # Shape: (B, S, D) + timestep: torch.LongTensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_all_hidden_states: bool = False, + ): + # Encode timesteps + temb = self.timestep_encoder(timestep) + + # Process through transformer blocks - single pass through the blocks + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() + + all_hidden_states = [hidden_states] + + # Process through transformer blocks + for idx, block in enumerate(self.transformer_blocks): + if idx % 2 == 1 and self.config.interleave_self_attention: + hidden_states = block( + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + temb=temb, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=None, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=None, + temb=temb, + ) + all_hidden_states.append(hidden_states) + + # Output processing + conditioning = temb + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + if return_all_hidden_states: + return self.proj_out_2(hidden_states), all_hidden_states + else: + return self.proj_out_2(hidden_states) + + +class SelfAttentionTransformer(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 8, + attention_head_dim: int = 64, + output_dim: int = 26, + num_layers: int = 12, + dropout: float = 0.1, + attention_bias: bool = True, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: int | None = 1000, + upcast_attention: bool = False, + max_num_positional_embeddings: int = 512, + compute_dtype=torch.float32, + final_dropout: bool = True, + positional_embeddings: str | None = "sinusoidal", + interleave_self_attention=False, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.gradient_checkpointing = False + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + activation_fn=self.config.activation_fn, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + positional_embeddings=positional_embeddings, + num_positional_embeddings=self.config.max_num_positional_embeddings, + final_dropout=final_dropout, + ) + for _ in range(self.config.num_layers) + ] + ) + print( + "Total number of SelfAttentionTransformer parameters: ", + sum(p.numel() for p in self.parameters() if p.requires_grad), + ) + + def forward( + self, + hidden_states: torch.Tensor, # Shape: (B, T, D) + return_all_hidden_states: bool = False, + ): + # Process through transformer blocks - single pass through the blocks + hidden_states = hidden_states.contiguous() + all_hidden_states = [hidden_states] + + # Process through transformer blocks + for _idx, block in enumerate(self.transformer_blocks): + hidden_states = block(hidden_states) + all_hidden_states.append(hidden_states) + + if return_all_hidden_states: + return hidden_states, all_hidden_states + else: + return hidden_states diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py new file mode 100644 index 00000000..bfc456ba --- /dev/null +++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import nn +from torch.distributions import Beta + +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers import PretrainedConfig + from transformers.feature_extraction_utils import BatchFeature +else: + PretrainedConfig = object + BatchFeature = None + +from lerobot.policies.groot.action_head.action_encoder import ( + SinusoidalPositionalEncoding, + swish, +) + +from .cross_attention_dit import DiT, SelfAttentionTransformer + + +class CategorySpecificLinear(nn.Module): + def __init__(self, num_categories, input_dim, hidden_dim): + super().__init__() + self.num_categories = num_categories + # For each category, we have separate weights and biases. + self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim)) + self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim)) + + def forward(self, x, cat_ids): + selected_w = self.W[cat_ids] + selected_b = self.b[cat_ids] + return torch.bmm(x, selected_w) + selected_b.unsqueeze(1) + + +class CategorySpecificMLP(nn.Module): + def __init__(self, num_categories, input_dim, hidden_dim, output_dim): + super().__init__() + self.num_categories = num_categories + self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim) + self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim) + + def forward(self, x, cat_ids): + hidden = F.relu(self.layer1(x, cat_ids)) + return self.layer2(hidden, cat_ids) + + +class MultiEmbodimentActionEncoder(nn.Module): + def __init__(self, action_dim, hidden_size, num_embodiments): + super().__init__() + self.hidden_size = hidden_size + self.num_embodiments = num_embodiments + + # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w} + self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w) + self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w) + self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w) + self.pos_encoding = SinusoidalPositionalEncoding(hidden_size) + + def forward(self, actions, timesteps, cat_ids): + """ + actions: shape (B, T, action_dim) + timesteps: shape (B,) -- a single scalar per batch item + cat_ids: shape (B,) + returns: shape (B, T, hidden_size) + """ + b, t, _ = actions.shape + + # 1) Expand each batch's single scalar time 'tau' across all T steps + # so that shape => (B, T) + # e.g. 
if timesteps is (B,), replicate across T + if timesteps.dim() == 1 and timesteps.shape[0] == b: + # shape (B,) => (B,T) + timesteps = timesteps.unsqueeze(1).expand(-1, t) + else: + raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.") + + # 2) Standard action MLP step for shape => (B, T, w) + a_emb = self.W1(actions, cat_ids) + + # 3) Get the sinusoidal encoding (B, T, w) + tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype) + + # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish + x = torch.cat([a_emb, tau_emb], dim=-1) + x = swish(self.W2(x, cat_ids)) + + # 5) Finally W3 => (B, T, w) + x = self.W3(x, cat_ids) + return x + + +@dataclass +class FlowmatchingActionHeadConfig(PretrainedConfig): + """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head""" + + add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"}) + model_dtype: str = field(default="float32", metadata={"help": "Model data type."}) + diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."}) + input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."}) + backbone_embedding_dim: int = field( + default=1536, metadata={"help": "Backbone embedding channel dimension."} + ) + + hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."}) + max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"}) + action_dim: int = field(default=None, metadata={"help": "Action dimension."}) + action_horizon: int = field(default=None, metadata={"help": "Action horizon."}) + noise_beta_alpha: float = field(default=1.5, metadata={"help": ""}) + noise_beta_beta: float = field(default=1.0, metadata={"help": ""}) + noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."}) + num_timestep_buckets: int = field( + default=1000, metadata={"help": "Number of timestep discretization buckets."} + ) + num_inference_timesteps: int = field( + default=None, + metadata={"help": "Number of inference steps for noise diffusion."}, + ) + max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."}) + tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."}) + tune_diffusion_model: bool = field( + default=True, metadata={"help": "Whether to tune the diffusion model."} + ) + load_pretrained_det_decode_layer_path: str = field( + default=None, metadata={"help": "Path to pretrained detection model."} + ) + detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."}) + + freeze_decode_layer: bool = field(default=False) + expand_batch: int = field(default=None) + use_vlln: bool = field(default=True) + + vl_self_attention_cfg: dict = field(default=None) + num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."}) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + for key, value in kwargs.items(): + setattr(self, key, value) + + +class FlowmatchingActionHead(nn.Module): + config_class = FlowmatchingActionHeadConfig + supports_gradient_checkpointing = True + + def __init__( + self, + config: FlowmatchingActionHeadConfig, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.input_embedding_dim = config.input_embedding_dim + + self.model = DiT(**config.diffusion_model_cfg) + self.action_dim = config.action_dim + 
self.action_horizon = config.action_horizon + self.num_inference_timesteps = config.num_inference_timesteps + + self.state_encoder = CategorySpecificMLP( + num_categories=config.max_num_embodiments, + input_dim=config.max_state_dim, + hidden_dim=self.hidden_size, + output_dim=self.input_embedding_dim, + ) + self.action_encoder = MultiEmbodimentActionEncoder( + action_dim=config.action_dim, + hidden_size=self.input_embedding_dim, + num_embodiments=config.max_num_embodiments, + ) + self.action_decoder = CategorySpecificMLP( + num_categories=config.max_num_embodiments, + input_dim=self.hidden_size, + hidden_dim=self.hidden_size, + output_dim=self.action_dim, + ) + self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim) + nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02) + + self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity() + self.vl_self_attention = ( + SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity() + ) + + if config.add_pos_embed: + self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) + nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) + + self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta) + self.num_timestep_buckets = config.num_timestep_buckets + self.config = config + self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model) + + def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool): + self.tune_projector = tune_projector + self.tune_diffusion_model = tune_diffusion_model + for p in self.parameters(): + p.requires_grad = True + if not tune_projector: + self.state_encoder.requires_grad_(False) + self.action_encoder.requires_grad_(False) + self.action_decoder.requires_grad_(False) + if self.config.add_pos_embed: + self.position_embedding.requires_grad_(False) + if not tune_diffusion_model: + self.model.requires_grad_(False) + print(f"Tune action head projector: {self.tune_projector}") + print(f"Tune action head diffusion model: {self.tune_diffusion_model}") + # Check if any parameters are still trainable. If not, print a warning. + if not tune_projector and not tune_diffusion_model: + for name, p in self.named_parameters(): + if p.requires_grad: + print(f"Action head trainable parameter: {name}") + if not any(p.requires_grad for p in self.parameters()): + print("Warning: No action head trainable parameters found.") + + def set_frozen_modules_to_eval_mode(self): + """ + Huggingface will call model.train() at each training_step. To ensure + the expected behaviors for modules like dropout, batchnorm, etc., we + need to call model.eval() for the frozen modules. 
+ """ + if self.training: + if not self.tune_projector: + self.state_encoder.eval() + self.action_encoder.eval() + self.action_decoder.eval() + if self.config.add_pos_embed: + self.position_embedding.eval() + if not self.tune_diffusion_model: + self.model.eval() + + def sample_time(self, batch_size, device, dtype): + sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype) + return (self.config.noise_s - sample) / self.config.noise_s + + def prepare_input(self, batch: dict) -> BatchFeature: + return BatchFeature(data=batch) + + def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature: + backbone_features = backbone_output["backbone_features"] + backbone_features = self.vlln(backbone_features) + backbone_features = self.vl_self_attention(backbone_features) + backbone_output["backbone_features"] = backbone_features + return backbone_output + + def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: + # Set frozen modules to eval + self.set_frozen_modules_to_eval_mode() + + backbone_output = self.process_backbone_output(backbone_output) + + if self.config.expand_batch is not None: + for k, v in backbone_output.items(): + ndim = len(v.shape) + factors = [self.config.expand_batch] + while len(factors) < ndim: + factors.append(1) + factors = tuple(factors) + expanded = v.repeat(*factors) + backbone_output[k] = expanded + + for k, v in action_input.items(): + ndim = len(v.shape) + factors = [self.config.expand_batch] + while len(factors) < ndim: + factors.append(1) + factors = tuple(factors) + expanded = v.repeat(*factors) + action_input[k] = expanded + + # Get vision and language embeddings. + vl_embs = backbone_output.backbone_features + device = vl_embs.device + + # Get embodiment ID. + embodiment_id = action_input.embodiment_id + + # Embed state. + state_features = self.state_encoder(action_input.state, embodiment_id) + + # Embed noised action trajectory. + actions = action_input.action + noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype) + t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype) + t = t[:, None, None] # shape (B,1,1) for broadcast + + noisy_trajectory = (1 - t) * noise + t * actions + velocity = actions - noise + + # Convert (continuous) t -> discrete if needed + t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long() + action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id) + + # Maybe add position embedding. + if self.config.add_pos_embed: + pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device) + pos_embs = self.position_embedding(pos_ids).unsqueeze(0) + action_features = action_features + pos_embs + + # Join vision, language, state and action embedding along sequence dimension. + future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1) + sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1) + + vl_attn_mask = backbone_output.backbone_attention_mask + + model_output = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embs, + encoder_attention_mask=vl_attn_mask, + timestep=t_discretized, + return_all_hidden_states=False, # NOTE (YL): not using flare now + ) + pred = self.action_decoder(model_output, embodiment_id) + pred_actions = pred[:, -actions.shape[1] :] + + # Slice out only the action portion of pred and target. 
+ action_mask = action_input.action_mask + loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask + loss = loss.sum() / action_mask.sum() + output_dict = { + "loss": loss, + } + return BatchFeature(data=output_dict) + + @torch.no_grad() + def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: + backbone_output = self.process_backbone_output(backbone_output) + + # Get vision and language embeddings. + vl_embs = backbone_output.backbone_features + embodiment_id = action_input.embodiment_id + + # Embed state. + state_features = self.state_encoder(action_input.state, embodiment_id) + + # Set initial actions as the sampled noise. + batch_size = vl_embs.shape[0] + device = vl_embs.device + actions = torch.randn( + size=(batch_size, self.config.action_horizon, self.config.action_dim), + dtype=vl_embs.dtype, + device=device, + ) + + num_steps = self.num_inference_timesteps + dt = 1.0 / num_steps + + # Run denoising steps. + for t in range(num_steps): + t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ... + t_discretized = int(t_cont * self.num_timestep_buckets) + + # Embed noised action trajectory. + timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device) + action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id) + # Maybe add position embedding. + if self.config.add_pos_embed: + pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device) + pos_embs = self.position_embedding(pos_ids).unsqueeze(0) + action_features = action_features + pos_embs + + # Join vision, language, state and action embedding along sequence dimension. + future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1) + sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1) + + # Run model forward. + model_output = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embs, + timestep=timesteps_tensor, + ) + pred = self.action_decoder(model_output, embodiment_id) + + pred_velocity = pred[:, -self.action_horizon :] + + # Update actions using euler integration. + actions = actions + dt * pred_velocity + return BatchFeature(data={"action_pred": actions}) + + @property + def device(self): + return next(iter(self.parameters())).device + + @property + def dtype(self): + return next(iter(self.parameters())).dtype diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py new file mode 100644 index 00000000..8002c69e --- /dev/null +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python + +# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("groot") +@dataclass +class GrootConfig(PreTrainedConfig): + """Configuration for Groot policy wrapper.""" + + # Basic policy settings + n_obs_steps: int = 1 + chunk_size: int = 50 + n_action_steps: int = 50 + + # Dimension settings (must match pretrained GR00T model expectations) + # Maximum state dimension. Shorter states will be zero-padded. + max_state_dim: int = 64 + + # Maximum action dimension. Shorter actions will be zero-padded. + max_action_dim: int = 32 + + # Normalization (start with identity, adjust as needed) + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Image preprocessing (adjust to match Groot's expected input) + image_size: tuple[int, int] = (224, 224) + + # Groot-specific model parameters (from groot_finetune_script.py) + + # Path or HuggingFace model ID for the base Groot model + base_model_path: str = "nvidia/GR00T-N1.5-3B" + + # HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer. + tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5" + + # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1') + embodiment_tag: str = "new_embodiment" + + # Fine-tuning control arguments + + # Whether to fine-tune the llm backbone + tune_llm: bool = False + + # Whether to fine-tune the vision tower + tune_visual: bool = False + + # Whether to fine-tune the projector + tune_projector: bool = True + + # Whether to fine-tune the diffusion model + tune_diffusion_model: bool = True + + # LoRA parameters (from groot_finetune_script.py) + # Rank for the LORA model. If 0, no LORA will be used. 
+ lora_rank: int = 0 + + # Alpha value for the LORA model + lora_alpha: int = 16 + + # Dropout rate for the LORA model + lora_dropout: float = 0.1 + + # Whether to use the full model for LORA + lora_full_model: bool = False + + # Training parameters (matching groot_finetune_script.py) + optimizer_lr: float = 1e-4 + optimizer_betas: tuple[float, float] = (0.95, 0.999) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-5 + warmup_ratio: float = 0.05 + use_bf16: bool = True + + # Dataset parameters + # Video backend to use for training ('decord' or 'torchvision_av') + video_backend: str = "decord" + + # Whether to balance dataset weights in mixture datasets + balance_dataset_weights: bool = True + + # Whether to sample trajectories weighted by their length + balance_trajectory_weights: bool = True + + # Optional dataset paths for delegating training to Isaac-GR00T runner + dataset_paths: list[str] | None = None + output_dir: str = "./tmp/gr00t" + save_steps: int = 1000 + max_steps: int = 10000 + batch_size: int = 32 + dataloader_num_workers: int = 8 + report_to: str = "wandb" + resume: bool = False + + def __post_init__(self): + super().__post_init__() + + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})" + ) + + # groot_repo_path is now optional since we ported the components + # No validation needed + + def validate_features(self) -> None: + """Validate and set up input/output features for Groot.""" + image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL] + if not image_features: + raise ValueError( + "Groot policy requires at least one visual input feature. " + "No features of type FeatureType.VISUAL found in input_features." + ) + + if "observation.state" not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), + ) + self.input_features["observation.state"] = state_feature + else: + state_shape = self.input_features["observation.state"].shape + state_dim = state_shape[0] if state_shape else 0 + if state_dim > self.max_state_dim: + raise ValueError( + f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. " + f"Either reduce state dimension or increase max_state_dim in config." + ) + + if "action" not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), + ) + self.output_features["action"] = action_feature + else: + action_shape = self.output_features["action"].shape + action_dim = action_shape[0] if action_shape else 0 + if action_dim > self.max_action_dim: + raise ValueError( + f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. " + f"Either reduce action dimension or increase max_action_dim in config." 
+ ) + + def get_optimizer_preset(self) -> AdamWConfig: + """Return optimizer configuration.""" + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: + """Return scheduler configuration.""" + return CosineDecayWithWarmupSchedulerConfig( + num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default + num_decay_steps=10000, # Adjust based on training steps + peak_lr=self.optimizer_lr, + decay_lr=self.optimizer_lr * 0.1, + ) + + @property + def observation_delta_indices(self) -> None: + """Return indices for delta observations (None for Groot).""" + return None + + @property + def action_delta_indices(self) -> list[int]: + """Return indices for delta actions.""" + return list(range(min(self.chunk_size, 16))) + + @property + def reward_delta_indices(self) -> None: + """Return indices for delta rewards (None for Groot).""" + return None diff --git a/src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py new file mode 100755 index 00000000..526b4f7a --- /dev/null +++ b/src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy + +from transformers.configuration_utils import PretrainedConfig +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.siglip.configuration_siglip import SiglipVisionConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Eagle25VLConfig(PretrainedConfig): + model_type = "eagle_2_5_vl" + is_composition = True + sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config} + + def __init__( + self, + vision_config=None, + text_config=None, + use_backbone_lora=0, + use_llm_lora=0, + pad2square=False, + select_layer=-4, + force_image_size=None, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + loss_version="v1", + min_dynamic_tiles=1, + max_dynamic_tiles=6, + mlp_checkpoint=False, + initializer_range=0.02, + _attn_implementation="flash_attention_2", + _attn_implementation_autoset=False, + llm_config=None, + image_token_index=None, + use_pixel_shuffle=True, + mlp_connector_layers=2, + **kwargs, + ): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {"model_type": "siglip_vision_model"} + logger.info("vision_config is None. 
Initializing the SiglipVisionConfig with default values.")
+
+        if text_config is None:
+            text_config = {"architectures": ["Qwen2ForCausalLM"]}
+            logger.info(
+                "text_config is None. Initializing the text config with default values (`Qwen2Config`)."
+            )
+
+        if vision_config["model_type"] == "siglip_vision_model":
+            self.vision_config = SiglipVisionConfig(**vision_config)
+        else:
+            raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
+
+        if text_config["architectures"][0] == "LlamaForCausalLM":
+            self.text_config = LlamaConfig(**text_config)
+        elif text_config["architectures"][0] == "Qwen2ForCausalLM":
+            self.text_config = Qwen2Config(**text_config)
+        elif text_config["architectures"][0] == "Qwen3ForCausalLM":
+            self.text_config = Qwen3Config(**text_config)
+        else:
+            raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
+        self.use_backbone_lora = use_backbone_lora
+        self.use_llm_lora = use_llm_lora
+        self.mlp_checkpoint = mlp_checkpoint
+        self.pad2square = pad2square
+        self.select_layer = select_layer
+        self.force_image_size = force_image_size
+        self.downsample_ratio = downsample_ratio
+        self.template = template
+        self.dynamic_image_size = dynamic_image_size
+        self.use_thumbnail = use_thumbnail
+        self.loss_version = loss_version
+        self.initializer_range = initializer_range
+        self.min_dynamic_tiles = min_dynamic_tiles
+        self.max_dynamic_tiles = max_dynamic_tiles
+        self.tie_word_embeddings = self.text_config.tie_word_embeddings
+        self._attn_implementation = _attn_implementation
+        self._attn_implementation_autoset = _attn_implementation_autoset
+        self.image_token_index = image_token_index
+        self.use_pixel_shuffle = use_pixel_shuffle
+        self.mlp_connector_layers = mlp_connector_layers
+        logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
+        logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
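+        The nested `vision_config` and `text_config` are serialized recursively via their own
+        `to_dict()` methods, while the remaining attributes come from a deep copy of `self.__dict__`.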
+ + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + output["model_type"] = self.__class__.model_type + output["use_backbone_lora"] = self.use_backbone_lora + output["use_llm_lora"] = self.use_llm_lora + output["pad2square"] = self.pad2square + output["select_layer"] = self.select_layer + output["force_image_size"] = self.force_image_size + output["downsample_ratio"] = self.downsample_ratio + output["template"] = self.template + output["dynamic_image_size"] = self.dynamic_image_size + output["use_thumbnail"] = self.use_thumbnail + output["min_dynamic_tiles"] = self.min_dynamic_tiles + output["max_dynamic_tiles"] = self.max_dynamic_tiles + output["tie_word_embeddings"] = self.tie_word_embeddings + output["_attn_implementation"] = self._attn_implementation + output["_attn_implementation_autoset"] = self._attn_implementation_autoset + output["use_pixel_shuffle"] = self.use_pixel_shuffle + output["mlp_connector_layers"] = self.mlp_connector_layers + return output diff --git a/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py new file mode 100644 index 00000000..6b4f6d7a --- /dev/null +++ b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py @@ -0,0 +1,504 @@ +# -------------------------------------------------------- +# NVIDIA +# Copyright (c) 2025 NVIDIA +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + + +# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +from typing import Optional + +from transformers.image_processing_utils import ( + BatchFeature, + get_patch_output_size, +) +from transformers.image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from transformers.image_utils import ( + IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5 + IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5 + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + get_image_size, + make_flat_list_of_images, + validate_kwargs, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_v2_available, +) +from transformers.video_utils import VideoInput + +if is_torch_available(): + import torch +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F # noqa: N812 + from transformers.image_utils import pil_torch_interpolation_mapping +else: + from torchvision.transforms import functional as F # noqa: N812 + + +def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor: + """Crop the given numpy array. + + Args: + img (torch.Tensor): Image to be cropped. Format should be (C, H, W). + left (int): The left coordinate of the crop box. + top (int): The top coordinate of the crop box. + right (int): The right coordinate of the crop box. + bottom (int): The bottom coordinate of the crop box. + + Returns: + torch.Tensor: Cropped image. + """ + if not isinstance(img, torch.Tensor): + raise TypeError(f"img should be torch.Tensor. 
Got {type(img)}") + + if img.ndim not in [2, 3]: + raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}") + + img_height = img.shape[1] + img_width = img.shape[2] + if top < 0 or left < 0 or bottom > img_height or right > img_width: + raise ValueError("Crop coordinates out of bounds") + + if top >= bottom or left >= right: + raise ValueError("Invalid crop coordinates") + + return img[:, top:bottom, left:right] + + +class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + max_dynamic_tiles: int | None + min_dynamic_tiles: int | None + use_thumbnail: bool | None + pad_during_tiling: bool | None + do_pad: bool | None + + +@add_start_docstrings( + "Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.", + # BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove! + """ + image_grid_pinpoints (`List[List[int]]`, *optional*): + A list of possible resolutions to use for processing high resolution images. The best resolution is selected + based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess` + method. Not used for processing videos. + do_pad (`bool`, *optional*): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + """, +) +class Eagle25VLImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 448, "width": 448} + default_to_square = False + crop_size = None + do_resize = True + do_center_crop = None + do_rescale = True + do_normalize = True + do_convert_rgb = True + do_pad = True + max_dynamic_tiles = 12 + min_dynamic_tiles = 1 + use_thumbnail = True + pad_during_tiling = False + valid_kwargs = Eagle25VLFastImageProcessorKwargs + model_input_names = ["pixel_values_videos"] + + def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]): + super().__init__(**kwargs) + + @add_start_docstrings( + # BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove! + """ + max_dynamic_tiles (`int`, *optional*): + The maximum number of dynamic tiles to use for processing high resolution images. + min_dynamic_tiles (`int`, *optional*): + The minimum number of dynamic tiles to use for processing high resolution images. + use_thumbnail (`bool`, *optional*): + Whether to use a thumbnail for processing high resolution images. + pad_during_tiling (`bool`, *optional*): + Whether to pad the image during tiling. + do_pad (`bool`, *optional*): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + """, + ) + + # NOTE(YL): we will overload the preprocess method to add the image_flags + # def preprocess( + # self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs] + # ) -> BatchFeature: + # return super().preprocess(images, **kwargs) + + def _prepare_images_structure( + self, + images: ImageInput, + expected_ndims: int = 3, + ) -> ImageInput: + """ + Prepare the images structure for processing. + + Args: + images (`ImageInput`): + The input images to process. 
+ expected_ndims (`int`, *optional*, defaults to 3): + Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility). + + Returns: + `ImageInput`: The images with a valid nesting. + """ + return make_flat_list_of_images(images) + + def _resize_for_patching( + self, + image: "torch.Tensor", + target_resolution: tuple, + interpolation: "F.InterpolationMode", + input_data_format: ChannelDimension, + ) -> "torch.Tensor": + """ + Resizes an image to a target resolution while maintaining aspect ratio. + + Args: + image ("torch.Tensor"): + The input image. + target_resolution (tuple): + The target resolution (height, width) of the image. + interpolation (`InterpolationMode`): + Resampling filter to use if resizing the image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + "torch.Tensor": The resized and padded image. + """ + new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + + # Resize the image + resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) + + return resized_image + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + """ + previous version mainly focus on ratio. + We also consider area ratio here. + """ + best_factor = float("-inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + # ratio_diff = abs(aspect_ratio - target_aspect_ratio) + # area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area + """ + new area > 60% of original image area is enough. + """ + factor_based_on_area_n_ratio = min( + (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6 + ) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio) + + if factor_based_on_area_n_ratio > best_factor: + best_factor = factor_based_on_area_n_ratio + best_ratio = ratio + + return best_ratio + + def _pad_for_patching( + self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension + ) -> "torch.Tensor": + """ + Pad an image to a target resolution while maintaining aspect ratio. 
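+
+        Args:
+            image (`torch.Tensor`):
+                The input image in (C, H, W) format.
+            target_resolution (`tuple`):
+                The target (height, width) to pad toward.
+            input_data_format (`ChannelDimension` or `str`):
+                The channel dimension format of the input image.
+
+        Returns:
+            `torch.Tensor`: The image centered with zero padding so that its size matches the target
+            resolution (up to rounding of the per-side padding).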
+ """ + target_height, target_width = target_resolution + new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) + + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + + padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y]) + + return padded_image + + def _get_image_patches( + self, + image: "torch.Tensor", + min_num: int, + max_num: int, + size: tuple, + tile_size: int, + use_thumbnail: bool, + interpolation: "F.InterpolationMode", + pad_during_tiling: bool, + ) -> list["torch.Tensor"]: + image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST) + orig_height, orig_width = image_size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, tile_size + ) + + # calculate the target width and height + target_width = tile_size * target_aspect_ratio[0] + target_height = tile_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + if pad_during_tiling: + resized_image = self._resize_for_patching( + image, + (target_height, target_width), + interpolation=interpolation, + input_data_format=ChannelDimension.FIRST, + ) + padded_image = self._pad_for_patching( + resized_image, + (target_height, target_width), + input_data_format=ChannelDimension.FIRST, + ) + image_used_to_split = padded_image + else: + image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation) + + processed_tiles = [] + for i in range(blocks): + box = ( + (i % (target_width // tile_size)) * tile_size, + (i // (target_width // tile_size)) * tile_size, + ((i % (target_width // tile_size)) + 1) * tile_size, + ((i // (target_width // tile_size)) + 1) * tile_size, + ) + # split the image + split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3]) + processed_tiles.append(split_img) + assert len(processed_tiles) == blocks + + if use_thumbnail and len(processed_tiles) != 1: + thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation) + processed_tiles.append(thumbnail_img) + + return processed_tiles + + def _pad_for_batching( + self, + pixel_values: list["torch.Tensor"], + ) -> list["torch.Tensor"]: + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + + Args: + pixel_values (`List[torch.Tensor]`): + An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`) + + Returns: + List[`torch.Tensor`]: The padded images. 
+ """ + max_patch = max(len(x) for x in pixel_values) + pixel_values = [ + torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]]) + for image in pixel_values + ] + + return pixel_values + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + max_dynamic_tiles: int, + min_dynamic_tiles: int, + use_thumbnail: bool, + pad_during_tiling: bool, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + do_pad: bool, + return_tensors: str | TensorType | None, + pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility + disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility + ) -> BatchFeature: + processed_images = [] + image_sizes = [] + # Determine the size tuple + if size and size.height and size.width: + size_tuple = (size.height, size.width) + else: + size_tuple = (size.shortest_edge, size.shortest_edge) + + # Determine the patch size + if crop_size and crop_size.height: + tile_size = crop_size.height + elif size and size.height: + tile_size = size.height + else: + tile_size = size.shortest_edge + + for image in images: + image_patches = self._get_image_patches( + image, + min_num=min_dynamic_tiles, + max_num=max_dynamic_tiles, + size=size_tuple, + tile_size=tile_size, + use_thumbnail=use_thumbnail, + interpolation=interpolation, + pad_during_tiling=pad_during_tiling, + ) + + # Group images by size for batched processing + processed_image_patches_grouped = {} + # Added for transformers >=4.53.0 compatibility + grouped_image_patches, grouped_image_patches_index = group_images_by_shape( + image_patches, + disable_grouping=disable_grouping, + ) + + for shape, stacked_image_patches in grouped_image_patches.items(): + if do_resize: + stacked_image_patches = self.resize( + image=stacked_image_patches, + size=size, + interpolation=interpolation, + ) + if do_center_crop: + stacked_image_patches = self.center_crop(stacked_image_patches, crop_size) + # Fused rescale and normalize + stacked_image_patches = self.rescale_and_normalize( + stacked_image_patches, + do_rescale, + rescale_factor, + do_normalize, + image_mean, + image_std, + ) + processed_image_patches_grouped[shape] = stacked_image_patches + processed_image_patches = reorder_images( + processed_image_patches_grouped, grouped_image_patches_index + ) + processed_image_patches = ( + torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches + ) + processed_images.append(processed_image_patches) + image_sizes.append(get_image_size(image, ChannelDimension.FIRST)) + + if do_pad: + processed_images = self._pad_for_batching(processed_images) + + # processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature( + data={"pixel_values": processed_images, "image_sizes": image_sizes}, + tensor_type=return_tensors, + ) + + def preprocess( + self, + images: ImageInput, + videos: VideoInput = None, + **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs], + ) -> BatchFeature: + validate_kwargs( + captured_kwargs=kwargs.keys(), + valid_processor_keys=self.valid_kwargs.__annotations__.keys(), + ) + # Set default kwargs from self. 
This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Extract parameters that are only used for preparing the input images + do_convert_rgb = kwargs.pop("do_convert_rgb") + input_data_format = kwargs.pop("input_data_format") + device = kwargs.pop("device") + # Prepare input images + # transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images + if images is not None: + images = self._prepare_image_like_inputs( + images=images, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + + if videos is not None: + videos = self._prepare_image_like_inputs( + images=videos, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + + # Update kwargs that need further processing before being validated + kwargs = self._further_process_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + # torch resize uses interpolation instead of resample + # Added for transformers >=4.53.0 compatibility + resample = kwargs.pop("resample", self.resample) + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] + if isinstance(resample, PILImageResampling | int) + else resample + ) + + # Filter kwargs to only include those accepted by _preprocess + valid_preprocess_kwargs = { + "do_resize", + "size", + "max_dynamic_tiles", + "min_dynamic_tiles", + "use_thumbnail", + "pad_during_tiling", + "interpolation", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_pad", + "return_tensors", + "pad_size", + "disable_grouping", + } + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs} + if images is not None: + return self._preprocess(images, **filtered_kwargs) + elif videos is not None: + return self._preprocess(videos, **filtered_kwargs) + + +__all__ = ["Eagle25VLImageProcessorFast"] diff --git a/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py new file mode 100755 index 00000000..5a66cfbc --- /dev/null +++ b/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py @@ -0,0 +1,395 @@ +# -------------------------------------------------------- +# NVIDIA +# Copyright (c) 2025 NVIDIA +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import inspect + +import torch +import torch.utils.checkpoint as cp +from peft import LoraConfig, get_peft_model +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import GenerationConfig +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM +from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM +from transformers.models.siglip.modeling_siglip import SiglipVisionModel +from transformers.utils import add_start_docstrings, logging + +from .configuration_eagle2_5_vl import Eagle25VLConfig + +logger = logging.get_logger(__name__) + + +# copy from 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1 +EAGLE2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Eagle25VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.", + EAGLE2_5_VL_START_DOCSTRING, +) +class Eagle25VLPreTrainedModel(PreTrainedModel): + config_class = Eagle25VLConfig + base_model_prefix = "model" + main_input_name = "input_ids" + supports_gradient_checkpointing = True + _no_split_modules = [ + "Qwen2DecoderLayer", + "LlamaDecoderLayer", + "Siglip2EncoderLayer", + "SiglipEncoderLayer", + ] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + _supports_static_cache = True + _supports_quantized_cache = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear | nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin): + config_class = Eagle25VLConfig + + def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None): + super().__init__(config) + + image_size = config.force_image_size or config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.patch_size = patch_size + if config.use_pixel_shuffle: + self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2)) + else: + self.num_image_token = int((image_size // patch_size) ** 2) + + self.select_layer = config.select_layer + self.downsample_ratio = config.downsample_ratio + self.loss_version = config.loss_version + self.mlp_checkpoint = config.mlp_checkpoint + self.use_pixel_shuffle = config.use_pixel_shuffle + self.mlp_connector_layers = config.mlp_connector_layers + logger.info(f"num_image_token: {self.num_image_token}") + logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}") + if vision_model is not None: + self.vision_model = vision_model + else: + if config.vision_config.model_type == "siglip_vision_model": + config.vision_config._attn_implementation = "flash_attention_2" + self.vision_model = SiglipVisionModel(config.vision_config) + else: + raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.") + + if language_model is not None: + self.language_model = language_model + else: + if config.text_config.architectures[0] == "LlamaForCausalLM": + 
self.language_model = LlamaForCausalLM(config.text_config) + elif config.text_config.architectures[0] == "Phi3ForCausalLM": + raise NotImplementedError("Phi3 is not implemented.") + # self.language_model = Phi3ForCausalLM(config.text_config) + elif config.text_config.architectures[0] == "Qwen2ForCausalLM": + assert config.text_config._attn_implementation == "flash_attention_2", ( + f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}" + ) + self.language_model = Qwen2ForCausalLM(config.text_config) + elif config.text_config.architectures[0] == "Qwen3ForCausalLM": + self.language_model = Qwen3ForCausalLM(config.text_config) + else: + raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.") + + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.text_config.hidden_size + + if config.mlp_connector_layers == 2: + self.mlp1 = nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size), + nn.GELU(), + nn.Linear(llm_hidden_size, llm_hidden_size), + ) + elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle: + self.mlp1 = nn.Sequential( + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size), + ) + elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle: + self.mlp1 = nn.Sequential( + nn.Linear(vit_hidden_size, llm_hidden_size), + ) + else: + raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.") + + self.image_token_index = config.image_token_index + self.neftune_alpha = None + + if config.use_backbone_lora: + self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora) + + self.use_llm_lora = config.use_llm_lora + if config.use_llm_lora: + self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora) + + self.check_forward_kwargs() + + def check_forward_kwargs(self): + # We intentionally avoid using **kwargs in forward because Hugging Face Transformers + # has special handling for functions with **kwargs parameters that would affect + # how our model is processed during training and inference. 
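+        # For example, a signature like `def forward(self, pixel_values, **kwargs)` would surface
+        # `**kwargs` as a VAR_KEYWORD parameter and trip the assertion below.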
+ forward_params = inspect.signature(self.forward).parameters + assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()) + + def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05): + lora_config = LoraConfig( + r=r, + target_modules=[ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.out_proj", + "mlp.fc1", + "mlp.fc2", + ], + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + self.vision_model = get_peft_model(self.vision_model, lora_config) + self.vision_model.print_trainable_parameters() + + def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05): + lora_config = LoraConfig( + r=r, + target_modules=[ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.down_proj", + "mlp.up_proj", + ], + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + task_type="CAUSAL_LM", + ) + self.language_model = get_peft_model(self.language_model, lora_config) + self.language_model.enable_input_require_grads() + self.language_model.print_trainable_parameters() + self.use_llm_lora = True + + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + image_flags: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + num_tiles_list: list[torch.Tensor] | None = None, + ) -> tuple | CausalLMOutputWithPast: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_embeds = self.language_model.get_input_embeddings()(input_ids) + + vit_embeds = self.extract_feature(pixel_values) + + if image_flags is not None: + image_flags = image_flags.view(-1) + vit_embeds = vit_embeds[image_flags == 1] + + b, n, c = input_embeds.shape + input_embeds = input_embeds.reshape(b * n, c) + + input_ids = input_ids.reshape(b * n) + selected = input_ids == self.image_token_index + try: + input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c) + except Exception as e: + vit_embeds = vit_embeds.reshape(-1, c) + print( + f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, " + f"vit_embeds.shape={vit_embeds.shape}" + ) + n_token = selected.sum() + input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token] + + input_embeds = input_embeds.reshape(b, n, c) + + outputs = self.language_model( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + logits = outputs.logits + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else 
output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))) + + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + if self.select_layer == -1: + vit_embeds = self.vision_model( + pixel_values=pixel_values, output_hidden_states=False, return_dict=True + ) + if hasattr(vit_embeds, "last_hidden_state"): + vit_embeds = vit_embeds.last_hidden_state + + else: + vit_embeds = self.vision_model( + pixel_values=pixel_values, output_hidden_states=True, return_dict=True + ).hidden_states[self.select_layer] + + if self.use_pixel_shuffle: + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle( + vit_embeds, scale_factor=self.downsample_ratio + ) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096]) + vit_embeds = vit_embeds.reshape( + vit_embeds.shape[0], -1, vit_embeds.shape[-1] + ) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096]) + + if self.mlp_checkpoint and vit_embeds.requires_grad: + vit_embeds = cp.checkpoint(self.mlp1, vit_embeds) + else: + vit_embeds = self.mlp1(vit_embeds) + + return vit_embeds + + @torch.no_grad() + def generate( + self, + pixel_values: torch.FloatTensor | None = None, + input_ids: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + visual_features: torch.FloatTensor | None = None, + generation_config: GenerationConfig | None = None, + output_hidden_states: bool | None = None, + image_sizes: list[tuple[int, int]] | None = None, + **generate_kwargs, + ) -> torch.LongTensor: + if pixel_values is not None: + if visual_features is not None: + vit_embeds = visual_features + else: + vit_embeds = self.extract_feature(pixel_values) + + input_embeds = self.language_model.get_input_embeddings()(input_ids) + b, n, c = input_embeds.shape + input_embeds = input_embeds.reshape(b * n, c) + + input_ids = input_ids.reshape(b * n) + selected = input_ids == self.config.image_token_index + assert selected.sum() != 0 + input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device) + + input_embeds = input_embeds.reshape(b, n, c) + else: + input_embeds = self.language_model.get_input_embeddings()(input_ids) + + if "use_cache" not in generate_kwargs: + generate_kwargs["use_cache"] = True + + outputs = self.language_model.generate( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + generation_config=generation_config, + output_hidden_states=output_hidden_states, + **generate_kwargs, + ) + + return outputs + + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings + def set_input_embeddings(self, value): + 
self.language_model.set_input_embeddings(value) + + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder + def get_decoder(self): + return self.language_model.get_decoder() diff --git a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py new file mode 100755 index 00000000..27f9b334 --- /dev/null +++ b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py @@ -0,0 +1,518 @@ +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Eagle25VL. 
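+
+It bundles the Eagle 2.5-VL fast image processor and a tokenizer so that images, videos and text
+prompts can be prepared for the model in a single call.
+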
+copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py +""" + +import base64 +import os +import re +from io import BytesIO + +import requests +import torch +from PIL import Image +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.utils import logging +from transformers.video_utils import VideoInput + +logger = logging.get_logger(__name__) + + +FRAME_FACTOR = 2 +FPS = 2.0 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 256 + + +def to_rgb(pil_image: Image.Image) -> Image.Image: + if pil_image.mode == "RGBA": + white_background = Image.new("RGB", pil_image.size, (255, 255, 255)) + white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask + return white_background + else: + return pil_image.convert("RGB") + + +def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image: + image = ele["image"] if "image" in ele else ele["image_url"] + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + response = requests.get(image, stream=True, timeout=10) + image_obj = Image.open(BytesIO(response.content)) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError( + f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}" + ) + image = to_rgb(image_obj) + if "scale_factor" in ele: + scale_factor = ele["scale_factor"] + image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR) + return image + + +class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False): + # see processing_utils.ProcessingKwargs documentation for usage. + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + "videos_kwargs": {"max_dynamic_tiles": 1}, + } + + +class Eagle25VLProcessor(ProcessorMixin): + r""" + Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor. + + [`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the + [`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information. + + Args: + image_processor ([`LlavaOnevisionImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + num_image_tokens (`int`, *optional*): + Number of image tokens for one imagethat will be returned by vision tower. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Should be same as in model's config + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
+ image_token (`str`, *optional*, defaults to `""`): + Special token used to denote image location. + video_token (`str`, *optional*, defaults to `"