feat(policies): add Nvidia Gr00t N1.5 model (#2292)
* feat(policies): add Nvidia Gr00t N1.5 model
* fix(docs): add groot to index

Co-authored-by: lbenhorin <lbenhorin@nvidia.com>
Co-authored-by: Aravindh <aravindhs@nvidia.com>
Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com>
Co-authored-by: youliangt <youliangt@nvidia.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>
.github/workflows/full_tests.yml (vendored, 2 changes)
@@ -78,7 +78,7 @@ jobs:
           python-version: ${{ env.PYTHON_VERSION }}

       - name: Install lerobot with all extras
-        run: uv sync --all-extras
+        run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional

       - name: Run pytest (all extras)
         run: uv run pytest tests -vv --maxfail=10

@@ -37,6 +37,8 @@
       title: π₀ (Pi0)
     - local: pi05
       title: π₀.₅ (Pi05)
+    - local: groot
+      title: Nvidia Gr00t N1.5
   title: "Policies"
 - sections:
   - local: il_sim

docs/source/groot.mdx (new file, 122 lines)
@@ -0,0 +1,122 @@
# Gr00t N1.5 Policy

Gr00t N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.

This document outlines the specifics of its integration and usage within the LeRobot framework.

## Model Overview

NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.

Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.

GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.

Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:

- Real captured data from robots.
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
- Internet-scale video data.

This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.

## Installation Requirements

As of today, Gr00t N1.5 requires Flash Attention to run.

We are working on making this optional; in the meantime, an extra installation step is required and the policy can only be used on CUDA-enabled devices.

1. Follow the Environment Setup section of our [Installation Guide](./installation). **Attention:** do not install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:

```bash
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```

3. Install LeRobot by running:

```bash
pip install "lerobot[groot]" # consider also installing the libero, dev, and test extras
```

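You can then check that the policy modules resolve correctly. This is a minimal sanity check added here for convenience, assuming the `lerobot.policies.groot` package exports `GrootPolicy` as introduced in this PR:

```bash
python -c "from lerobot.policies.groot import GrootPolicy; print('groot extra OK')"
```
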
## Usage

To use Gr00t in your LeRobot configuration, specify the policy type as:

```bash
policy.type=groot
```

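If you prefer to build the configuration in Python instead of on the command line, a minimal sketch looks like this (assuming the `lerobot[groot]` extra and flash-attn are installed, and that default construction of the config is valid in your version; `GrootConfig` is the class registered under the `groot` type in this PR):

```python
from lerobot.policies.groot import GrootConfig

cfg = GrootConfig()                         # registered under policy.type="groot"
print(cfg.base_model_path)                  # "nvidia/GR00T-N1.5-3B" pretrained backbone
print(cfg.chunk_size, cfg.n_action_steps)   # 50 50 by default
```
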
## Training

### Training Command Example

Here's a complete training command for finetuning the base Gr00t model on your own dataset:

```bash
# Using a multi-GPU setup
accelerate launch \
    --multi_gpu \
    --num_processes=$NUM_GPUS \
    $(which lerobot-train) \
    --output_dir=$OUTPUT_DIR \
    --save_checkpoint=true \
    --batch_size=$BATCH_SIZE \
    --steps=$NUM_STEPS \
    --save_freq=$SAVE_FREQ \
    --log_freq=$LOG_FREQ \
    --policy.push_to_hub=true \
    --policy.type=groot \
    --policy.repo_id=$REPO_ID \
    --policy.tune_diffusion_model=false \
    --dataset.repo_id=$DATASET_ID \
    --wandb.enable=true \
    --wandb.disable_artifact=true \
    --job_name=$JOB_NAME
```

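The command above relies on shell variables as placeholders. Purely as an illustration (the values below are not prescriptive), they could be set like this before launching:

```bash
export NUM_GPUS=2
export OUTPUT_DIR=outputs/groot_finetune
export BATCH_SIZE=16
export NUM_STEPS=20000   # e.g. the 20k steps used for the Libero comparison below
export SAVE_FREQ=5000
export LOG_FREQ=100
export REPO_ID=<user>/groot-finetuned      # Hub repo for the trained policy
export DATASET_ID=<user>/my-robot-dataset  # LeRobot dataset to finetune on
export JOB_NAME=groot_finetune
```
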
## Performance Results

### Libero Benchmark Results

Gr00t has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the Gr00t N1.5 model for 20k steps on the Libero dataset and compared the results to the Gr00t reference results.

| Benchmark          | LeRobot Implementation | Gr00t Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82%                    | 92.0%           |
| **Libero Object**  | 99%                    | 92.0%           |
| **Libero Long**    | 72%                    | 76.0%           |
| **Average**        | 84%                    | 87.0%           |

These results demonstrate Gr00t's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.

### Evaluate on your hardware setup

Once you have trained your model, you can run inference on your downstream task. For running the policy on your own hardware, follow our instructions in [Imitation Learning for Robots](./il_robots). For example:

```bash
lerobot-record \
    --robot.type=bi_so100_follower \
    --robot.left_arm_port=/dev/ttyACM1 \
    --robot.right_arm_port=/dev/ttyACM0 \
    --robot.id=bimanual_follower \
    --robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
                       left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
                       top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30} }' \
    --display_data=true \
    --dataset.repo_id=<user>/eval_groot-bimanual \
    --dataset.num_episodes=10 \
    --dataset.single_task="Grab and handover the red cube to the other arm" \
    --dataset.episode_time_s=30 \
    --dataset.reset_time_s=10 \
    --policy.path=<user>/groot-bimanual  # your trained model
```

## License

This model follows the **Apache 2.0 License**, consistent with the original [Gr00t repository](https://github.com/NVIDIA/Isaac-GR00T).
docs/source/policy_groot_README.md (new file, 27 lines)
@@ -0,0 +1,27 @@
## Research Paper

Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/

## Repository

Code: https://github.com/NVIDIA/Isaac-GR00T

## Citation

```bibtex
@inproceedings{gr00tn1_2025,
    archivePrefix = {arxiv},
    eprint        = {2503.14734},
    title         = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
    author        = {NVIDIA and Johan Bjorck and Fernando Castañeda and Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
    month         = {March},
    year          = {2025},
    booktitle     = {ArXiv Preprint},
}
```

## Additional Resources

Blog: https://developer.nvidia.com/isaac/gr00t

Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
@@ -118,6 +118,17 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
 # Policies
 pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
 smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
+groot = [
+    "lerobot[transformers-dep]",
+    "peft>=0.13.0,<1.0.0",
+    "dm-tree>=0.1.8,<1.0.0",
+    "timm>=1.0.0,<1.1.0",
+    "safetensors>=0.4.3,<1.0.0",
+    "Pillow>=10.0.0,<13.0.0",
+    "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
+    "ninja>=1.11.1,<2.0.0",
+    "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
+]
 hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]

 # Features
@@ -145,6 +156,7 @@ all = [
     "lerobot[intelrealsense]",
     "lerobot[pi]",
     "lerobot[smolvla]",
+    # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
     "lerobot[hilserl]",
     "lerobot[async]",
     "lerobot[dev]",
@@ -243,6 +255,7 @@ default.extend-ignore-identifiers-re = [
    "pn",
    "ser",
    "ein",
    "thw",
    "inpt",
]

@@ -14,6 +14,7 @@

 from .act.configuration_act import ACTConfig as ACTConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
+from .groot.configuration_groot import GrootConfig as GrootConfig
 from .pi0.configuration_pi0 import PI0Config as PI0Config
 from .pi05.configuration_pi05 import PI05Config as PI05Config
 from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
@@ -29,4 +30,5 @@ __all__ = [
     "SmolVLAConfig",
     "TDMPCConfig",
     "VQBeTConfig",
+    "GrootConfig",
 ]

@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
 from lerobot.envs.utils import env_to_policy_features
 from lerobot.policies.act.configuration_act import ACTConfig
 from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.policies.groot.configuration_groot import GrootConfig
 from lerobot.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.policies.pi05.configuration_pi05 import PI05Config
 from lerobot.policies.pretrained import PreTrainedPolicy
@@ -101,6 +102,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

         return SmolVLAPolicy
+    elif name == "groot":
+        from lerobot.policies.groot.modeling_groot import GrootPolicy
+
+        return GrootPolicy
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")

@@ -142,6 +147,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return SmolVLAConfig(**kwargs)
     elif policy_type == "reward_classifier":
         return RewardClassifierConfig(**kwargs)
+    elif policy_type == "groot":
+        return GrootConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{policy_type}' is not available.")

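For reference, a minimal sketch of how these two factory helpers are typically used together (usage inferred from the signatures in this hunk; the keyword override is illustrative only):

```python
from lerobot.policies.factory import get_policy_class, make_policy_config

cfg = make_policy_config("groot", n_action_steps=50)  # returns a GrootConfig
policy_cls = get_policy_class("groot")                # lazily imports GrootPolicy
```
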
@@ -199,6 +206,27 @@ def make_pre_post_processors(
         policy configuration type.
     """
     if pretrained_path:
+        # TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
+        if isinstance(policy_cfg, GrootConfig):
+            # GROOT handles normalization in groot_pack_inputs_v3 step
+            # Need to override both stats AND normalize_min_max since saved config might be empty
+            preprocessor_overrides = {}
+            postprocessor_overrides = {}
+            preprocessor_overrides["groot_pack_inputs_v3"] = {
+                "stats": kwargs.get("dataset_stats"),
+                "normalize_min_max": True,
+            }
+
+            # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
+            env_action_dim = policy_cfg.output_features["action"].shape[0]
+            postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
+                "stats": kwargs.get("dataset_stats"),
+                "normalize_min_max": True,
+                "env_action_dim": env_action_dim,
+            }
+            kwargs["preprocessor_overrides"] = preprocessor_overrides
+            kwargs["postprocessor_overrides"] = postprocessor_overrides
+
         return (
             PolicyProcessorPipeline.from_pretrained(
                 pretrained_model_name_or_path=pretrained_path,
@@ -293,6 +321,14 @@
             dataset_stats=kwargs.get("dataset_stats"),
         )

+    elif isinstance(policy_cfg, GrootConfig):
+        from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
+
+        processors = make_groot_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+        )
+
     else:
         raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")

src/lerobot/policies/groot/README.md (new symbolic link)
@@ -0,0 +1 @@
../../../../docs/source/policy_groot_README.md
src/lerobot/policies/groot/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
#!/usr/bin/env python

# Copyright 2025 Nvidia and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors

__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
src/lerobot/policies/groot/action_head/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/lerobot/policies/groot/action_head/action_encoder.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


def swish(x):
    return x * torch.sigmoid(x)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Produces a sinusoidal encoding of shape (B, T, w)
    given timesteps of shape (B, T).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        # timesteps: shape (B, T)
        # We'll compute sin/cos frequencies across dim T
        timesteps = timesteps.float()  # ensure float

        b, t = timesteps.shape
        device = timesteps.device

        half_dim = self.embedding_dim // 2
        # typical log space frequencies for sinusoidal encoding
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            torch.log(torch.tensor(10000.0)) / half_dim
        )
        # Expand timesteps to (B, T, 1) then multiply
        freqs = timesteps.unsqueeze(-1) * exponent.exp()  # (B, T, half_dim)

        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, w)

        return enc
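Editor-added sketch (not part of the original file): a quick shape check for the encoder above. The output is always `(B, T, embedding_dim)` because the sin and cos halves are concatenated:

```python
import torch

# assumes SinusoidalPositionalEncoding from action_encoder.py is in scope
enc = SinusoidalPositionalEncoding(embedding_dim=128)
timesteps = torch.randint(0, 1000, (4, 16))  # (B=4, T=16)
print(enc(timesteps).shape)                  # torch.Size([4, 16, 128])
```
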
src/lerobot/policies/groot/action_head/cross_attention_dit.py (new executable file, 370 lines)
@@ -0,0 +1,370 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import (
|
||||
SinusoidalPositionalEmbedding,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps):
|
||||
dtype = next(self.parameters()).dtype
|
||||
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
chunk_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = embedding_dim * 2
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
temb: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
temb = self.linear(self.silu(temb))
|
||||
scale, shift = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: int | None = None,
|
||||
activation_fn: str = "geglu",
|
||||
attention_bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: str | None = None,
|
||||
num_positional_embeddings: int | None = None,
|
||||
ff_inner_dim: int | None = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.norm_type = norm_type
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
if final_dropout:
|
||||
self.final_dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
self.final_dropout = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
temb: torch.LongTensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# 0. Self-Attention
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, temb)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 8,
|
||||
attention_head_dim: int = 64,
|
||||
output_dim: int = 26,
|
||||
num_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: int | None = 1000,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "ada_norm",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
max_num_positional_embeddings: int = 512,
|
||||
compute_dtype=torch.float32,
|
||||
final_dropout: bool = True,
|
||||
positional_embeddings: str | None = "sinusoidal",
|
||||
interleave_self_attention=False,
|
||||
cross_attention_dim: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Timestep encoder
|
||||
self.timestep_encoder = TimestepEncoder(
|
||||
embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype
|
||||
)
|
||||
|
||||
all_blocks = []
|
||||
for idx in range(self.config.num_layers):
|
||||
use_self_attn = idx % 2 == 1 and interleave_self_attention
|
||||
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
|
||||
|
||||
all_blocks += [
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
||||
norm_eps=self.config.norm_eps,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=final_dropout,
|
||||
cross_attention_dim=curr_cross_attention_dim,
|
||||
)
|
||||
]
|
||||
self.transformer_blocks = nn.ModuleList(all_blocks)
|
||||
|
||||
# Output blocks
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
|
||||
timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
# Encode timesteps
|
||||
temb = self.timestep_encoder(timestep)
|
||||
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1 and self.config.interleave_self_attention:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
# Output processing
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
else:
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 8,
|
||||
attention_head_dim: int = 64,
|
||||
output_dim: int = 26,
|
||||
num_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: int | None = 1000,
|
||||
upcast_attention: bool = False,
|
||||
max_num_positional_embeddings: int = 512,
|
||||
compute_dtype=torch.float32,
|
||||
final_dropout: bool = True,
|
||||
positional_embeddings: str | None = "sinusoidal",
|
||||
interleave_self_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=final_dropout,
|
||||
)
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for _idx, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(hidden_states)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
if return_all_hidden_states:
|
||||
return hidden_states, all_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
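Editor-added sketch (not part of the original file): a minimal shape check for the `DiT` module defined in this file, assuming `torch` and `diffusers` are installed and using the defaults above (8 heads × 64 head dims → inner_dim 512):

```python
import torch

# DiT as defined in cross_attention_dit.py; defaults give inner_dim = 8 * 64 = 512.
dit = DiT(num_attention_heads=8, attention_head_dim=64, output_dim=26, num_layers=2)

hidden = torch.randn(2, 82, 512)     # (B, state + future + action tokens, inner_dim)
context = torch.randn(2, 64, 512)    # (B, vision-language tokens, inner_dim)
timestep = torch.randint(0, 1000, (2,))

out = dit(hidden_states=hidden, encoder_hidden_states=context, timestep=timestep)
print(out.shape)                     # torch.Size([2, 82, 26]) per-token output_dim prediction
```
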
@@ -0,0 +1,406 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = None
|
||||
|
||||
from lerobot.policies.groot.action_head.action_encoder import (
|
||||
SinusoidalPositionalEncoding,
|
||||
swish,
|
||||
)
|
||||
|
||||
from .cross_attention_dit import DiT, SelfAttentionTransformer
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
b, t, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == b:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, t)
|
||||
else:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
|
||||
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
|
||||
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
|
||||
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
|
||||
backbone_embedding_dim: int = field(
|
||||
default=1536, metadata={"help": "Backbone embedding channel dimension."}
|
||||
)
|
||||
|
||||
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
|
||||
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
|
||||
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
|
||||
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
|
||||
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
|
||||
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
|
||||
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
|
||||
num_timestep_buckets: int = field(
|
||||
default=1000, metadata={"help": "Number of timestep discretization buckets."}
|
||||
)
|
||||
num_inference_timesteps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of inference steps for noise diffusion."},
|
||||
)
|
||||
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
|
||||
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
|
||||
tune_diffusion_model: bool = field(
|
||||
default=True, metadata={"help": "Whether to tune the diffusion model."}
|
||||
)
|
||||
load_pretrained_det_decode_layer_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained detection model."}
|
||||
)
|
||||
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
|
||||
|
||||
freeze_decode_layer: bool = field(default=False)
|
||||
expand_batch: int = field(default=None)
|
||||
use_vlln: bool = field(default=True)
|
||||
|
||||
vl_self_attention_cfg: dict = field(default=None)
|
||||
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
config_class = FlowmatchingActionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlowmatchingActionHeadConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
self.model = DiT(**config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
|
||||
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
self.vl_self_attention = (
|
||||
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
|
||||
)
|
||||
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
|
||||
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
print(f"Tune action head projector: {self.tune_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_projector and not tune_diffusion_model:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Action head trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = backbone_output["backbone_features"]
|
||||
backbone_features = self.vlln(backbone_features)
|
||||
backbone_features = self.vl_self_attention(backbone_features)
|
||||
backbone_output["backbone_features"] = backbone_features
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
# Set frozen modules to eval
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
if self.config.expand_batch is not None:
|
||||
for k, v in backbone_output.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
backbone_output[k] = expanded
|
||||
|
||||
for k, v in action_input.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
action_input[k] = expanded
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
device = vl_embs.device
|
||||
|
||||
# Get embodiment ID.
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
vl_attn_mask = backbone_output.backbone_attention_mask
|
||||
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=vl_attn_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=False, # NOTE (YL): not using flare now
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Slice out only the action portion of pred and target.
|
||||
action_mask = action_input.action_mask
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = loss.sum() / action_mask.sum()
|
||||
output_dict = {
|
||||
"loss": loss,
|
||||
}
|
||||
return BatchFeature(data=output_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Set initial actions as the sampled noise.
|
||||
batch_size = vl_embs.shape[0]
|
||||
device = vl_embs.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=vl_embs.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# Run denoising steps.
|
||||
for t in range(num_steps):
|
||||
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
# Run model forward.
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
|
||||
pred_velocity = pred[:, -self.action_horizon :]
|
||||
|
||||
# Update actions using euler integration.
|
||||
actions = actions + dt * pred_velocity
|
||||
return BatchFeature(data={"action_pred": actions})
|
||||
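    # [Editor note, not part of the original file] forward() above trains with the
    # standard flow matching objective: the noisy sample is
    #     x_t = (1 - t) * noise + t * actions,
    # and the network regresses the constant velocity v = actions - noise.
    # get_action() integrates the predicted velocity with explicit Euler steps,
    #     actions <- actions + (1 / num_inference_timesteps) * v_pred,
    # starting from pure Gaussian noise.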
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
src/lerobot/policies/groot/configuration_groot.py (new file, 201 lines)
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@dataclass
|
||||
class GrootConfig(PreTrainedConfig):
|
||||
"""Configuration for Groot policy wrapper."""
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 64
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
base_model_path: str = "nvidia/GR00T-N1.5-3B"
|
||||
|
||||
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
|
||||
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
|
||||
# Fine-tuning control arguments
|
||||
|
||||
# Whether to fine-tune the llm backbone
|
||||
tune_llm: bool = False
|
||||
|
||||
# Whether to fine-tune the vision tower
|
||||
tune_visual: bool = False
|
||||
|
||||
# Whether to fine-tune the projector
|
||||
tune_projector: bool = True
|
||||
|
||||
# Whether to fine-tune the diffusion model
|
||||
tune_diffusion_model: bool = True
|
||||
|
||||
# LoRA parameters (from groot_finetune_script.py)
|
||||
# Rank for the LORA model. If 0, no LORA will be used.
|
||||
lora_rank: int = 0
|
||||
|
||||
# Alpha value for the LORA model
|
||||
lora_alpha: int = 16
|
||||
|
||||
# Dropout rate for the LORA model
|
||||
lora_dropout: float = 0.1
|
||||
|
||||
# Whether to use the full model for LORA
|
||||
lora_full_model: bool = False
|
||||
|
||||
# Training parameters (matching groot_finetune_script.py)
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
max_steps: int = 10000
|
||||
batch_size: int = 32
|
||||
dataloader_num_workers: int = 8
|
||||
report_to: str = "wandb"
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# groot_repo_path is now optional since we ported the components
|
||||
# No validation needed
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features for Groot."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"Groot policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if "observation.state" not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features["observation.state"] = state_feature
|
||||
else:
|
||||
state_shape = self.input_features["observation.state"].shape
|
||||
state_dim = state_shape[0] if state_shape else 0
|
||||
if state_dim > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
|
||||
f"Either reduce state dimension or increase max_state_dim in config."
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features["action"] = action_feature
|
||||
else:
|
||||
action_shape = self.output_features["action"].shape
|
||||
action_dim = action_shape[0] if action_shape else 0
|
||||
if action_dim > self.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
|
||||
f"Either reduce action dimension or increase max_action_dim in config."
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Return optimizer configuration."""
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Return scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
|
||||
num_decay_steps=10000, # Adjust based on training steps
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.optimizer_lr * 0.1,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
"""Return indices for delta observations (None for Groot)."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
return list(range(min(self.chunk_size, 16)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""Return indices for delta rewards (None for Groot)."""
|
||||
return None
|
||||
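The optimizer and scheduler presets above wire the fine-tuning defaults (AdamW with cosine decay and a small warmup fraction) into LeRobot's `AdamWConfig` and `CosineDecayWithWarmupSchedulerConfig`. Below is a minimal, self-contained sketch of the arithmetic they encode; the class and field names are illustrative only, not the actual LeRobot API:

from dataclasses import dataclass

@dataclass
class SchedulerPresetSketch:
    # Mirrors the defaults above: 5% warmup over a 10k-step cosine schedule,
    # decaying from the peak learning rate down to 10% of it.
    warmup_ratio: float = 0.05
    num_decay_steps: int = 10_000
    peak_lr: float = 1e-4

    def num_warmup_steps(self) -> int:
        return int(self.num_decay_steps * self.warmup_ratio)

    def decay_lr(self) -> float:
        return self.peak_lr * 0.1

preset = SchedulerPresetSketch()
assert preset.num_warmup_steps() == 500
assert abs(preset.decay_lr() - 1e-5) < 1e-12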
135
src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py
Executable file
@@ -0,0 +1,135 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
|
||||
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Eagle25VLConfig(PretrainedConfig):
|
||||
model_type = "eagle_2_5_vl"
|
||||
is_composition = True
|
||||
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
pad2square=False,
|
||||
select_layer=-4,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
loss_version="v1",
|
||||
min_dynamic_tiles=1,
|
||||
max_dynamic_tiles=6,
|
||||
mlp_checkpoint=False,
|
||||
initializer_range=0.02,
|
||||
_attn_implementation="flash_attention_2",
|
||||
_attn_implementation_autoset=False,
|
||||
llm_config=None,
|
||||
image_token_index=None,
|
||||
use_pixel_shuffle=True,
|
||||
mlp_connector_layers=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {"model_type": "siglip_vision_model"}
|
||||
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {"architectures": ["Qwen2ForCausalLM"]}
|
||||
logger.info(
|
||||
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
|
||||
)
|
||||
|
||||
if vision_config["model_type"] == "siglip_vision_model":
|
||||
self.vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
|
||||
|
||||
if text_config["architectures"][0] == "LlamaForCausalLM":
|
||||
self.text_config = LlamaConfig(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
|
||||
self.text_config = Qwen2Config(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
|
||||
self.text_config = Qwen3Config(**text_config)
|
||||
else:
|
||||
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.mlp_checkpoint = mlp_checkpoint
|
||||
self.pad2square = pad2square
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.loss_version = loss_version
|
||||
self.initializer_range = initializer_range
|
||||
self.min_dynamic_tiles = min_dynamic_tiles
|
||||
self.max_dynamic_tiles = max_dynamic_tiles
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self._attn_implementation = _attn_implementation
|
||||
self._attn_implementation_autoset = _attn_implementation_autoset
|
||||
self.image_token_index = image_token_index
|
||||
self.use_pixel_shuffle = use_pixel_shuffle
|
||||
self.mlp_connector_layers = mlp_connector_layers
|
||||
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
|
||||
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
output["use_backbone_lora"] = self.use_backbone_lora
|
||||
output["use_llm_lora"] = self.use_llm_lora
|
||||
output["pad2square"] = self.pad2square
|
||||
output["select_layer"] = self.select_layer
|
||||
output["force_image_size"] = self.force_image_size
|
||||
output["downsample_ratio"] = self.downsample_ratio
|
||||
output["template"] = self.template
|
||||
output["dynamic_image_size"] = self.dynamic_image_size
|
||||
output["use_thumbnail"] = self.use_thumbnail
|
||||
output["min_dynamic_tiles"] = self.min_dynamic_tiles
|
||||
output["max_dynamic_tiles"] = self.max_dynamic_tiles
|
||||
output["tie_word_embeddings"] = self.tie_word_embeddings
|
||||
output["_attn_implementation"] = self._attn_implementation
|
||||
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
|
||||
output["use_pixel_shuffle"] = self.use_pixel_shuffle
|
||||
output["mlp_connector_layers"] = self.mlp_connector_layers
|
||||
return output
|
||||
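A hedged usage sketch for the composition config defined above. The import path is inferred from the file path shown in this diff and may differ; it assumes a `transformers` version that provides the Siglip and Qwen2 configs:

from lerobot.policies.groot.eagle2_hg_model.configuration_eagle2_5_vl import Eagle25VLConfig

config = Eagle25VLConfig()                       # defaults: Siglip vision tower + Qwen2 text backbone
print(config.model_type)                         # "eagle_2_5_vl"
print(config.vision_config.model_type)           # "siglip_vision_model"
print(config.text_config.architectures)          # ["Qwen2ForCausalLM"]
print(config.to_dict()["mlp_connector_layers"])  # 2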
@@ -0,0 +1,504 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from typing import Optional
|
||||
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
|
||||
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_flat_list_of_images,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from transformers.image_utils import pil_torch_interpolation_mapping
|
||||
else:
|
||||
from torchvision.transforms import functional as F # noqa: N812
|
||||
|
||||
|
||||
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
|
||||
"""Crop the given numpy array.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
|
||||
left (int): The left coordinate of the crop box.
|
||||
top (int): The top coordinate of the crop box.
|
||||
right (int): The right coordinate of the crop box.
|
||||
bottom (int): The bottom coordinate of the crop box.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Cropped image.
|
||||
"""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
|
||||
|
||||
if img.ndim not in [2, 3]:
|
||||
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
|
||||
|
||||
img_height = img.shape[1]
|
||||
img_width = img.shape[2]
|
||||
if top < 0 or left < 0 or bottom > img_height or right > img_width:
|
||||
raise ValueError("Crop coordinates out of bounds")
|
||||
|
||||
if top >= bottom or left >= right:
|
||||
raise ValueError("Invalid crop coordinates")
|
||||
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
pad_during_tiling: bool | None
|
||||
do_pad: bool | None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was deprecated in transformers, remove!
|
||||
"""
|
||||
image_grid_pinpoints (`List[List[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method. Not used for processing videos.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 448, "width": 448}
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_pad = True
|
||||
max_dynamic_tiles = 12
|
||||
min_dynamic_tiles = 1
|
||||
use_thumbnail = True
|
||||
pad_during_tiling = False
|
||||
valid_kwargs = Eagle25VLFastImageProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was deprecated in transformers, remove!
|
||||
"""
|
||||
max_dynamic_tiles (`int`, *optional*):
|
||||
The maximum number of dynamic tiles to use for processing high resolution images.
|
||||
min_dynamic_tiles (`int`, *optional*):
|
||||
The minimum number of dynamic tiles to use for processing high resolution images.
|
||||
use_thumbnail (`bool`, *optional*):
|
||||
Whether to use a thumbnail for processing high resolution images.
|
||||
pad_during_tiling (`bool`, *optional*):
|
||||
Whether to pad the image during tiling.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
|
||||
# NOTE(YL): we will overload the preprocess method to add the image_flags
|
||||
# def preprocess(
|
||||
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
|
||||
# ) -> BatchFeature:
|
||||
# return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
expected_ndims: int = 3,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
expected_ndims (`int`, *optional*, defaults to 3):
|
||||
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_flat_list_of_images(images)
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
target_resolution: tuple,
|
||||
interpolation: "F.InterpolationMode",
|
||||
input_data_format: ChannelDimension,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image ("torch.Tensor"):
|
||||
The input image.
|
||||
target_resolution (tuple):
|
||||
The target resolution (height, width) of the image.
|
||||
interpolation (`InterpolationMode`):
|
||||
Resampling filter to use if resizing the image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
"torch.Tensor": The resized and padded image.
|
||||
"""
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
# Resize the image
|
||||
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
|
||||
|
||||
return resized_image
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
"""
|
||||
The previous version mainly focused on the aspect ratio.
Here we also take the area ratio into account.
|
||||
"""
|
||||
best_factor = float("-inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
|
||||
"""
|
||||
new area > 60% of original image area is enough.
|
||||
"""
|
||||
factor_based_on_area_n_ratio = min(
|
||||
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
|
||||
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
|
||||
|
||||
if factor_based_on_area_n_ratio > best_factor:
|
||||
best_factor = factor_based_on_area_n_ratio
|
||||
best_ratio = ratio
|
||||
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
|
||||
|
||||
return padded_image
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: "F.InterpolationMode",
|
||||
pad_during_tiling: bool,
|
||||
) -> list["torch.Tensor"]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = tile_size * target_aspect_ratio[0]
|
||||
target_height = tile_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if pad_during_tiling:
|
||||
resized_image = self._resize_for_patching(
|
||||
image,
|
||||
(target_height, target_width),
|
||||
interpolation=interpolation,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
padded_image = self._pad_for_patching(
|
||||
resized_image,
|
||||
(target_height, target_width),
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
image_used_to_split = padded_image
|
||||
else:
|
||||
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
|
||||
|
||||
processed_tiles = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // tile_size)) * tile_size,
|
||||
(i // (target_width // tile_size)) * tile_size,
|
||||
((i % (target_width // tile_size)) + 1) * tile_size,
|
||||
((i // (target_width // tile_size)) + 1) * tile_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
|
||||
processed_tiles.append(split_img)
|
||||
assert len(processed_tiles) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_tiles) != 1:
|
||||
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
|
||||
processed_tiles.append(thumbnail_img)
|
||||
|
||||
return processed_tiles
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list["torch.Tensor"],
|
||||
) -> list["torch.Tensor"]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch with the same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[torch.Tensor]`):
|
||||
An array of pixel values for each image, of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
|
||||
Returns:
|
||||
List[`torch.Tensor`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
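# Illustrative note: with three images tiled into 5, 8 and 3 patches respectively,
# the first and last tensors are zero-padded along the patch dimension to 8 patches
# each, so all three can later be stacked into a single batch tensor.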
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
do_pad: bool,
|
||||
return_tensors: str | TensorType | None,
|
||||
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
else:
|
||||
size_tuple = (size.shortest_edge, size.shortest_edge)
|
||||
|
||||
# Determine the patch size
|
||||
if crop_size and crop_size.height:
|
||||
tile_size = crop_size.height
|
||||
elif size and size.height:
|
||||
tile_size = size.height
|
||||
else:
|
||||
tile_size = size.shortest_edge
|
||||
|
||||
for image in images:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
min_num=min_dynamic_tiles,
|
||||
max_num=max_dynamic_tiles,
|
||||
size=size_tuple,
|
||||
tile_size=tile_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
interpolation=interpolation,
|
||||
pad_during_tiling=pad_during_tiling,
|
||||
)
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
|
||||
image_patches,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
if do_center_crop:
|
||||
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
image_mean,
|
||||
image_std,
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(
|
||||
processed_image_patches_grouped, grouped_image_patches_index
|
||||
)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.append(processed_image_patches)
|
||||
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
|
||||
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
|
||||
)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
|
||||
if images is not None:
|
||||
images = self._prepare_image_like_inputs(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if videos is not None:
|
||||
videos = self._prepare_image_like_inputs(
|
||||
images=videos,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
resample = kwargs.pop("resample", self.resample)
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample]
|
||||
if isinstance(resample, PILImageResampling | int)
|
||||
else resample
|
||||
)
|
||||
|
||||
# Filter kwargs to only include those accepted by _preprocess
|
||||
valid_preprocess_kwargs = {
|
||||
"do_resize",
|
||||
"size",
|
||||
"max_dynamic_tiles",
|
||||
"min_dynamic_tiles",
|
||||
"use_thumbnail",
|
||||
"pad_during_tiling",
|
||||
"interpolation",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_pad",
|
||||
"return_tensors",
|
||||
"pad_size",
|
||||
"disable_grouping",
|
||||
}
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
|
||||
if images is not None:
|
||||
return self._preprocess(images, **filtered_kwargs)
|
||||
elif videos is not None:
|
||||
return self._preprocess(videos, **filtered_kwargs)
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLImageProcessorFast"]
|
||||
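The dynamic tiling above selects a (columns, rows) grid of fixed-size tiles that trades aspect-ratio similarity against covered area. A self-contained sketch of that selection rule, using illustrative names and the defaults from this file (448-pixel tiles, 1 to 12 tiles):

def pick_tile_grid(width, height, tile_size=448, min_num=1, max_num=12):
    aspect = width / height
    area = width * height
    candidates = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda r: r[0] * r[1],
    )

    def score(ratio):
        target_aspect = ratio[0] / ratio[1]
        # Covering more than ~60% of the original area stops paying off.
        area_factor = min((ratio[0] * ratio[1] * tile_size * tile_size) / area, 0.6)
        return area_factor * min(target_aspect / aspect, aspect / target_aspect)

    return max(candidates, key=score)

# A 1280x720 frame lands on a wide 4x2 grid of 448x448 tiles
# (plus an optional thumbnail tile when use_thumbnail is enabled).
print(pick_tile_grid(1280, 720))  # (4, 2)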
395
src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py
Executable file
@@ -0,0 +1,395 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as cp
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
||||
from transformers.utils import add_start_docstrings, logging
|
||||
|
||||
from .configuration_eagle2_5_vl import Eagle25VLConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
|
||||
EAGLE2_5_VL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Eagle25VLConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
|
||||
EAGLE2_5_VL_START_DOCSTRING,
|
||||
)
|
||||
class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
config_class = Eagle25VLConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Qwen2DecoderLayer",
|
||||
"LlamaDecoderLayer",
|
||||
"Siglip2EncoderLayer",
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear | nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
|
||||
config_class = Eagle25VLConfig
|
||||
|
||||
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
|
||||
super().__init__(config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
if config.use_pixel_shuffle:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
|
||||
else:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2)
|
||||
|
||||
self.select_layer = config.select_layer
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.loss_version = config.loss_version
|
||||
self.mlp_checkpoint = config.mlp_checkpoint
|
||||
self.use_pixel_shuffle = config.use_pixel_shuffle
|
||||
self.mlp_connector_layers = config.mlp_connector_layers
|
||||
logger.info(f"num_image_token: {self.num_image_token}")
|
||||
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
|
||||
if vision_model is not None:
|
||||
self.vision_model = vision_model
|
||||
else:
|
||||
if config.vision_config.model_type == "siglip_vision_model":
|
||||
config.vision_config._attn_implementation = "flash_attention_2"
|
||||
self.vision_model = SiglipVisionModel(config.vision_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
|
||||
|
||||
if language_model is not None:
|
||||
self.language_model = language_model
|
||||
else:
|
||||
if config.text_config.architectures[0] == "LlamaForCausalLM":
|
||||
self.language_model = LlamaForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
|
||||
raise NotImplementedError("Phi3 is not implemented.")
|
||||
# self.language_model = Phi3ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
|
||||
assert config.text_config._attn_implementation == "flash_attention_2", (
|
||||
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
|
||||
)
|
||||
self.language_model = Qwen2ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
|
||||
self.language_model = Qwen3ForCausalLM(config.text_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
if config.mlp_connector_layers == 2:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size, llm_hidden_size),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
|
||||
|
||||
self.image_token_index = config.image_token_index
|
||||
self.neftune_alpha = None
|
||||
|
||||
if config.use_backbone_lora:
|
||||
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
|
||||
|
||||
self.use_llm_lora = config.use_llm_lora
|
||||
if config.use_llm_lora:
|
||||
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
|
||||
|
||||
self.check_forward_kwargs()
|
||||
|
||||
def check_forward_kwargs(self):
|
||||
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
|
||||
# has special handling for functions with **kwargs parameters that would affect
|
||||
# how our model is processed during training and inference.
|
||||
forward_params = inspect.signature(self.forward).parameters
|
||||
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
|
||||
|
||||
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.out_proj",
|
||||
"mlp.fc1",
|
||||
"mlp.fc2",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
self.vision_model = get_peft_model(self.vision_model, lora_config)
|
||||
self.vision_model.print_trainable_parameters()
|
||||
|
||||
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.o_proj",
|
||||
"mlp.gate_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.up_proj",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.language_model = get_peft_model(self.language_model, lora_config)
|
||||
self.language_model.enable_input_require_grads()
|
||||
self.language_model.print_trainable_parameters()
|
||||
self.use_llm_lora = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
image_flags: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_tiles_list: list[torch.Tensor] | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
if image_flags is not None:
|
||||
image_flags = image_flags.view(-1)
|
||||
vit_embeds = vit_embeds[image_flags == 1]
|
||||
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.image_token_index
|
||||
try:
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
|
||||
except Exception as e:
|
||||
vit_embeds = vit_embeds.reshape(-1, c)
|
||||
print(
|
||||
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
|
||||
f"vit_embeds.shape={vit_embeds.shape}"
|
||||
)
|
||||
n_token = selected.sum()
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
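# Illustrative note: for whole-row selection the splice above behaves like
#   input_embeds = input_embeds.masked_scatter(selected.unsqueeze(-1), vit_embeds)
# i.e. every <IMG_CONTEXT> position in input_ids is replaced by one vision embedding
# (the `* 0.0 +` keeps the text-embedding table in the autograd graph at image positions).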
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
|
||||
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
if self.select_layer == -1:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
|
||||
)
|
||||
if hasattr(vit_embeds, "last_hidden_state"):
|
||||
vit_embeds = vit_embeds.last_hidden_state
|
||||
|
||||
else:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
|
||||
).hidden_states[self.select_layer]
|
||||
|
||||
if self.use_pixel_shuffle:
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
|
||||
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
|
||||
|
||||
if self.mlp_checkpoint and vit_embeds.requires_grad:
|
||||
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
|
||||
else:
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
|
||||
return vit_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
input_ids: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
visual_features: torch.FloatTensor | None = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
image_sizes: list[tuple[int, int]] | None = None,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
if pixel_values is not None:
|
||||
if visual_features is not None:
|
||||
vit_embeds = visual_features
|
||||
else:
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.config.image_token_index
|
||||
assert selected.sum() != 0
|
||||
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
else:
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if "use_cache" not in generate_kwargs:
|
||||
generate_kwargs["use_cache"] = True
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
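The `pixel_shuffle` step above trades spatial vision tokens for channel width before the MLP connector. A small sketch (assumes PyTorch is installed; the tensor sizes are illustrative) of how a 32x32 token grid with `downsample_ratio=0.5` becomes a 16x16 grid with four times the channels:

import torch

def pixel_shuffle_sketch(x, scale_factor=0.5):
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor**2)))
    return x.permute(0, 2, 1, 3).contiguous()

tokens = torch.randn(2, 32, 32, 1152)  # (B, H, W, C) grid of vision-tower features
shuffled = pixel_shuffle_sketch(tokens)
print(shuffled.shape)  # torch.Size([2, 16, 16, 4608]): 4x fewer tokens, 4x wider channels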
518
src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py
Executable file
@@ -0,0 +1,518 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Eagle25VL.
|
||||
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
FRAME_FACTOR = 2
|
||||
FPS = 2.0
|
||||
FPS_MIN_FRAMES = 4
|
||||
FPS_MAX_FRAMES = 256
|
||||
|
||||
|
||||
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||
if pil_image.mode == "RGBA":
|
||||
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||
return white_background
|
||||
else:
|
||||
return pil_image.convert("RGB")
|
||||
|
||||
|
||||
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
|
||||
image = ele["image"] if "image" in ele else ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(image, stream=True, timeout=10)
|
||||
image_obj = Image.open(BytesIO(response.content))
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = to_rgb(image_obj)
|
||||
if "scale_factor" in ele:
|
||||
scale_factor = ele["scale_factor"]
|
||||
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
|
||||
return image
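# e.g. fetch_image({"image": "https://example.com/cat.png"}) or
#      fetch_image({"image": "data:image/png;base64,<...>"}) both return an RGB PIL.Image;
# "file://" paths, plain paths and PIL.Image inputs are handled the same way.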
|
||||
|
||||
|
||||
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
# see processing_utils.ProcessingKwargs documentation for usage.
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
"videos_kwargs": {"max_dynamic_tiles": 1},
|
||||
}
|
||||
|
||||
|
||||
class Eagle25VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an Eagle25VL processor which wraps an Eagle25VL video processor, an Eagle25VL image processor and an Eagle25VL tokenizer into a single processor.
|
||||
|
||||
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
|
||||
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
num_image_tokens (`int`, *optional*):
|
||||
Number of image tokens for one image that will be returned by the vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be same as in model's config
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"num_image_tokens",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"video_token",
|
||||
"images_kwargs",
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<IMG_CONTEXT>", # nosec: B107
|
||||
video_token="<IMG_CONTEXT>", # nosec: B107
|
||||
tokens_per_tile=256,
|
||||
image_placeholder="image",
|
||||
video_placeholder="video",
|
||||
image_start_token="<img>",
|
||||
image_end_token="</img>",
|
||||
**kwargs,
|
||||
):
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
self.video_token_id = (
|
||||
tokenizer.video_token_id
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
self.image_placeholder = image_placeholder
|
||||
self.video_placeholder = video_placeholder
|
||||
self.tokens_per_tile = tokens_per_tile
|
||||
self.image_start_token = image_start_token
|
||||
self.image_end_token = image_end_token
|
||||
if "auto_map" in kwargs:
|
||||
self.auto_map = kwargs["auto_map"]
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def replace_media_placeholder(
|
||||
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
|
||||
):
|
||||
num_of_images_in_this_sample = 0
|
||||
num_of_videos_in_this_sample = 0
|
||||
# Regular expression pattern to match formats like <image-1> or <video-2>
|
||||
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
|
||||
unified_frame_list = []
|
||||
|
||||
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
# )
|
||||
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
# )
|
||||
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
|
||||
# "use_thumbnail", self.image_processor.use_thumbnail
|
||||
# )
|
||||
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
)
|
||||
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
)
|
||||
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
|
||||
"use_thumbnail", self.image_processor.use_thumbnail
|
||||
)
|
||||
|
||||
tile_size = self.image_processor.size.get("height", 448)
|
||||
|
||||
# Function to replace tags in a single text
|
||||
def replace_in_text(text):
|
||||
# repl callback function for each match replacement operation
|
||||
def repl(match):
|
||||
nonlocal unified_frame_list
|
||||
nonlocal num_of_images_in_this_sample
|
||||
nonlocal num_of_videos_in_this_sample
|
||||
media_type = match.group(1) # 'image' or 'video'
|
||||
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
|
||||
# Select the corresponding path based on media type
|
||||
idx_mapper = {
|
||||
0: "first",
|
||||
1: "second",
|
||||
2: "third",
|
||||
3: "fourth",
|
||||
4: "fifth",
|
||||
5: "sixth",
|
||||
6: "seventh",
|
||||
7: "eighth",
|
||||
8: "ninth",
|
||||
9: "tenth",
|
||||
}
|
||||
if media_type == "image":
|
||||
image_inputs = self.image_processor(
|
||||
images=[image_list[idx_in_list]],
|
||||
videos=None,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
unified_frame_list.append(image_inputs)
|
||||
num_of_images_in_this_sample += 1
|
||||
|
||||
elif media_type == "video":
|
||||
video_inputs = self.image_processor(
|
||||
images=None,
|
||||
videos=[video_list[idx_in_list]],
|
||||
**output_kwargs["videos_kwargs"],
|
||||
)
|
||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||
image_sizes = video_inputs["image_sizes"]
|
||||
if timestamps_list is not None and -1 not in timestamps_list:
|
||||
frame_timestamps = timestamps_list[idx_in_list]
|
||||
else:
|
||||
frame_timestamps = None
|
||||
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
|
||||
|
||||
num_of_tiles_each_frame = [
|
||||
self.get_number_tiles_based_on_image_size(
|
||||
image_size,
|
||||
video_min_dynamic_tiles,
|
||||
video_max_dynamic_tiles,
|
||||
video_use_thumbnail,
|
||||
tile_size,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
|
||||
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
|
||||
)
|
||||
|
||||
if frame_timestamps is not None:
|
||||
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
|
||||
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
|
||||
)
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
else:
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
|
||||
if sampled_fps is not None:
|
||||
special_placeholder = (
|
||||
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
|
||||
+ "".join(special_placeholder)
|
||||
)
|
||||
else:
|
||||
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
|
||||
special_placeholder
|
||||
)
|
||||
unified_frame_list.append(video_inputs)
|
||||
num_of_videos_in_this_sample += 1
|
||||
else:
|
||||
raise ValueError(f"Unknown media type: {media_type}")
|
||||
return special_placeholder
|
||||
|
||||
return pattern.sub(repl, text)
|
||||
|
||||
text = replace_in_text(text)
|
||||
if len(unified_frame_list) > 0:
|
||||
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
|
||||
else:
|
||||
pixel_values = None
|
||||
image_sizes = None
|
||||
return (
|
||||
text,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
audio=None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
|
||||
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Eagle25VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
    text_list = [text]
elif isinstance(text, list) and all(isinstance(t, str) for t in text):
    text_list = text
else:
    raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if videos is None:
|
||||
videos = []
|
||||
|
||||
pixel_values_list = []
|
||||
image_sizes_list = []
|
||||
new_sample_list = []
|
||||
image_start_idx = 0
|
||||
video_start_idx = 0
|
||||
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
|
||||
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
|
||||
for sample in text_list:
|
||||
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
|
||||
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
|
||||
(
|
||||
sample,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
) = self.replace_media_placeholder(
|
||||
sample,
|
||||
images[image_start_idx:],
|
||||
videos[video_start_idx:],
|
||||
timestamps_list,
|
||||
fps_list,
|
||||
**output_kwargs,
|
||||
)
|
||||
new_sample_list.append(sample)
|
||||
if pixel_values is not None:
|
||||
pixel_values_list.append(pixel_values)
|
||||
image_sizes_list.append(image_sizes)
|
||||
image_start_idx += num_of_images_in_this_sample
|
||||
video_start_idx += num_of_videos_in_this_sample
|
||||
|
||||
if len(pixel_values_list) > 0:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(pixel_values_list),
|
||||
"image_sizes": torch.cat(image_sizes_list),
|
||||
}
|
||||
else:
|
||||
image_inputs = {}
|
||||
video_inputs = {}
|
||||
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
|
||||
|
||||
def get_number_tiles_based_on_image_size(
|
||||
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of tiles based on the image size.
|
||||
"""
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if use_thumbnail and tiles_num > 1:
|
||||
tiles_num += 1
|
||||
return tiles_num
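# Illustrative example (assumed behavior of find_closest_aspect_ratio): a 448x896 image with
# tile_size=448, min_num=1, max_num=12 maps to a 1x2 grid -> 2 tiles, and use_thumbnail=True
# adds a thumbnail tile for 3 in total.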
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# override to save video-config in a separate config file
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
outputs = super().save_pretrained(save_directory, **kwargs)
|
||||
return outputs
|
||||
|
||||
# override to load video-config from a separate config file
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
|
||||
if isinstance(processor, tuple):
|
||||
processor = processor[0]
|
||||
return processor
|
||||
|
||||
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
def process_vision_info(
|
||||
self,
|
||||
conversations: list[dict] | list[list[dict]],
|
||||
return_video_kwargs: bool = False,
|
||||
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
|
||||
vision_infos = self.extract_vision_info(conversations)
|
||||
## Read images or videos
|
||||
image_inputs = []
|
||||
video_inputs = []
|
||||
video_sample_fps_list = []
|
||||
video_timestamps_list = []
|
||||
for vision_info in vision_infos:
|
||||
if "image" in vision_info or "image_url" in vision_info:
|
||||
image_inputs.append(fetch_image(vision_info))
|
||||
else:
|
||||
raise ValueError("image, image_url or video should in content.")
|
||||
if len(image_inputs) == 0:
|
||||
image_inputs = None
|
||||
if len(video_inputs) == 0:
|
||||
video_inputs = None
|
||||
if return_video_kwargs:
|
||||
return (
|
||||
image_inputs,
|
||||
video_inputs,
|
||||
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
|
||||
)
|
||||
return image_inputs, video_inputs
|
||||
|
||||
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
||||
vision_infos = []
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if (
|
||||
"image" in ele
|
||||
or "image_url" in ele
|
||||
or "video" in ele
|
||||
or ele["type"] in ("image", "image_url", "video")
|
||||
):
|
||||
vision_infos.append(ele)
|
||||
return vision_infos
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLProcessor"]
|
||||
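A minimal usage sketch of the processor above, mirroring how GrootEagleEncodeStep (later in this PR) drives it; the local path is an assumption, and any directory holding these processor files works:

from PIL import Image
from transformers import AutoProcessor

# Assumed local directory containing the Eagle processor files added in this PR.
proc = AutoProcessor.from_pretrained("path/to/eagle2_hg_model", trust_remote_code=True, use_fast=True)

img = Image.new("RGB", (224, 224))
conv = [
    {
        "role": "user",
        "content": [{"type": "image", "image": img}, {"type": "text", "text": "Pick up the cube."}],
    }
]
text = proc.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
image_inputs, _ = proc.process_vision_info(conv)
inputs = proc(text=[text], images=image_inputs, return_tensors="pt", padding=True)
# inputs is a BatchFeature with input_ids, attention_mask, pixel_values and image_sizes.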
376
src/lerobot/policies/groot/groot_n1.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
from lerobot.policies.groot.action_head.flow_matching_action_head import (
|
||||
FlowmatchingActionHead,
|
||||
FlowmatchingActionHeadConfig,
|
||||
)
|
||||
from lerobot.policies.groot.utils import ensure_eagle_cache_ready
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
class EagleBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
|
||||
project_to_dim: int = 1536,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tune_llm: whether to tune the LLM model (default: False)
|
||||
tune_visual: whether to tune the visual model (default: False)
|
||||
"""
|
||||
super().__init__()
|
||||
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
|
||||
|
||||
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
|
||||
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
try:
|
||||
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
|
||||
|
||||
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
|
||||
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
if project_to_dim is not None:
|
||||
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
|
||||
else:
|
||||
self.eagle_linear = torch.nn.Identity()
|
||||
|
||||
# needed since we don't use these layers. Also saves compute
|
||||
while len(self.eagle_model.language_model.model.layers) > select_layer:
|
||||
self.eagle_model.language_model.model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual)
|
||||
|
||||
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.eagle_model.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.eagle_model.vision_model.requires_grad_(False)
|
||||
self.eagle_model.mlp1.requires_grad_(False)
|
||||
print(f"Tune backbone llm: {self.tune_llm}")
|
||||
print(f"Tune backbone visual: {self.tune_visual}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_llm and not tune_visual:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Backbone trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No backbone trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if self.eagle_model.language_model and not self.tune_llm:
|
||||
self.eagle_model.language_model.eval()
|
||||
if self.eagle_model.vision_model and not self.tune_visual:
|
||||
self.eagle_model.vision_model.eval()
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
eagle_prefix = "eagle_"
|
||||
eagle_input = {
|
||||
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
|
||||
}
|
||||
del eagle_input["image_sizes"]
|
||||
|
||||
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
|
||||
eagle_features = eagle_output.hidden_states[self.select_layer]
|
||||
|
||||
eagle_features = self.eagle_linear(eagle_features)
|
||||
return eagle_features, eagle_input["attention_mask"]
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
|
||||
|
||||
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
|
||||
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
|
||||
if self.training and self.tune_visual:
|
||||
dummy_term = torch.tensor(
|
||||
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
|
||||
)
|
||||
for param in self.eagle_model.vision_model.parameters():
|
||||
if param.requires_grad:
|
||||
dummy_term = dummy_term + 0.0 * param.sum()
|
||||
eagle_embeds = eagle_embeds + dummy_term
|
||||
|
||||
return BatchFeature(
|
||||
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
|
||||
) # [B, T2, hidden_size]
|
||||
|
||||
|
||||
BACKBONE_FEATURE_KEY = "backbone_features"
|
||||
ACTION_KEY = "action_pred"
|
||||
LOSS_KEY = "loss"
|
||||
ERROR_MSG = "Error: unexpected input/output"
|
||||
N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@dataclass
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
|
||||
|
||||
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
|
||||
|
||||
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
|
||||
|
||||
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
|
||||
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# real model
|
||||
class GR00TN15(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
config_class = GR00TN15Config
|
||||
"""
|
||||
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
|
||||
here n is variable and can be e.g. time, 1 or user specified
|
||||
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
|
||||
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN15Config,
|
||||
local_model_path: str,
|
||||
):
|
||||
assert isinstance(config.backbone_cfg, dict)
|
||||
assert isinstance(config.action_head_cfg, dict)
|
||||
|
||||
super().__init__(config)
|
||||
self.local_model_path = local_model_path
|
||||
|
||||
self.backbone = EagleBackbone(**config.backbone_cfg)
|
||||
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
|
||||
self.action_head = FlowmatchingActionHead(action_head_cfg)
|
||||
|
||||
self.action_horizon = config.action_horizon
|
||||
self.action_dim = config.action_dim
|
||||
self.compute_dtype = config.compute_dtype
|
||||
|
||||
def validate_inputs(self, inputs):
|
||||
# NOTE -- this should be handled internally by the model
|
||||
# however, doing that would likely be a breaking change -- so we'll need to do it after the deadline
|
||||
|
||||
detected_error = False
|
||||
error_msg = ERROR_MSG
|
||||
if "action" in inputs:
|
||||
action = inputs["action"]
|
||||
# In inference, action may be omitted or None; validate only when it's a tensor.
|
||||
if action is None:
|
||||
pass # allow None during inference
|
||||
elif isinstance(action, torch.Tensor):
|
||||
shape_ok = (
|
||||
len(action.shape) == 3
|
||||
and action.shape[1] == self.action_horizon
|
||||
and action.shape[2] == self.action_dim
|
||||
)
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{action.shape=}"
|
||||
detected_error = True
|
||||
else:
|
||||
# Unexpected non-tensor type provided for action
|
||||
error_msg += f"\nInvalid type for action: {type(action)}"
|
||||
detected_error = True
|
||||
|
||||
if "video" in inputs:
|
||||
video = inputs["video"]
|
||||
type_ok = isinstance(video, np.ndarray)
|
||||
dtype_ok = video.dtype == np.uint8
|
||||
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
|
||||
if not type_ok:
|
||||
error_msg += f"\n{type(video)=}"
|
||||
detected_error = True
|
||||
if not dtype_ok:
|
||||
error_msg += f"\n{video.dtype=}"
|
||||
detected_error = True
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{video.shape=}"
|
||||
detected_error = True
|
||||
|
||||
if detected_error:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
|
||||
fail_backbone = (
|
||||
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
|
||||
)
|
||||
|
||||
if fail_backbone:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
|
||||
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
|
||||
(
|
||||
LOSS_KEY in action_head_outputs and is_training
|
||||
) # there might not be an action prediction during training
|
||||
or (
|
||||
ACTION_KEY in action_head_outputs
|
||||
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
|
||||
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
|
||||
)
|
||||
)
|
||||
|
||||
if fail_action_head:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
|
||||
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
|
||||
error_msg += f"\n{self.action_horizon=}"
|
||||
error_msg += f"\n{self.action_dim=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
|
||||
return action_head_outputs
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
|
||||
return action_head_outputs
|
||||
|
||||
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
|
||||
self.validate_inputs(inputs)
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_maybe_dtype(x):
|
||||
# Cast floating tensors to a memory-efficient compute dtype when requested.
|
||||
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
|
||||
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
if getattr(self, "compute_dtype", None) == "bfloat16":
|
||||
return x.to(self.device, dtype=torch.bfloat16)
|
||||
# Fallback: preserve previous behavior if not using bf16 compute
|
||||
return x.to(self.device, dtype=self.action_head.dtype)
|
||||
# Non-floating tensors: move device only
|
||||
return x.to(self.device)
|
||||
|
||||
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
|
||||
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
|
||||
return backbone_inputs, action_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
|
||||
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
|
||||
print(f"Tune backbone vision tower: {tune_visual}")
|
||||
print(f"Tune backbone LLM: {tune_llm}")
|
||||
print(f"Tune action head projector: {tune_projector}")
|
||||
print(f"Tune action head DiT: {tune_diffusion_model}")
|
||||
|
||||
# get the current model path being downloaded
|
||||
try:
|
||||
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
|
||||
# saved in ~/.cache/huggingface/hub/
|
||||
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
|
||||
# HFValidationError, RepositoryNotFoundError
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
print(
|
||||
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
|
||||
)
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path, local_model_path=local_model_path, **kwargs
|
||||
)
|
||||
|
||||
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
|
||||
)
|
||||
return pretrained_model
|
||||
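A brief loading sketch for GR00TN15 as defined above; the checkpoint id is an assumption, and a local path works as well (the except branch above covers that case):

from lerobot.policies.groot.groot_n1 import GR00TN15

# Checkpoint id is illustrative; pass a local directory if the model is not on the Hub.
model = GR00TN15.from_pretrained(
    "nvidia/GR00T-N1.5-3B",
    tune_llm=False,
    tune_visual=False,
    tune_projector=True,
    tune_diffusion_model=True,
)
model.eval()
# model.get_action(inputs) returns a BatchFeature with "action_pred" of shape
# (batch_size, action_horizon, action_dim).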
198
src/lerobot/policies/groot/modeling_groot.py
Normal file
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T components where possible
|
||||
without porting their code. The intent is to:
|
||||
|
||||
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
|
||||
- Optionally align action horizon similar to gr00t_finetune.py
|
||||
- Expose predict_action via GR00T model.get_action
|
||||
- Provide a training forward that can call the GR00T model forward if batch
|
||||
structure matches.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1 import GR00TN15
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class GrootPolicy(PreTrainedPolicy):
|
||||
"""Wrapper around external Groot model for LeRobot integration."""
|
||||
|
||||
name = "groot"
|
||||
config_class = GrootConfig
|
||||
|
||||
def __init__(self, config: GrootConfig):
|
||||
"""Initialize Groot policy wrapper."""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# Initialize GR00T model using ported components
|
||||
self._groot_model = self._create_groot_model()
|
||||
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T model using Isaac-GR00T API.
|
||||
|
||||
This is only called when creating a NEW policy (not when loading from checkpoint).
|
||||
|
||||
Steps (delegating to Isaac-GR00T):
|
||||
1) Download and load pretrained model via GR00TN15.from_pretrained
|
||||
2) Align action horizon with data_config if provided
|
||||
"""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
model = GR00TN15.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.base_model_path,
|
||||
tune_llm=self.config.tune_llm,
|
||||
tune_visual=self.config.tune_visual,
|
||||
tune_projector=self.config.tune_projector,
|
||||
tune_diffusion_model=self.config.tune_diffusion_model,
|
||||
)
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Training forward pass.
|
||||
|
||||
Delegates to Isaac-GR00T model.forward when inputs are compatible.
|
||||
"""
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.forward(groot_inputs)
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch
|
||||
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
original_action_dim = self.config.output_features["action"].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select single action from action queue."""
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
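An inference-loop sketch for the wrapper above; `policy`, `preprocessor`, and the observation source are assumptions (see processor_groot.py below for how the batch is actually built):

# policy: a GrootPolicy restored from a checkpoint; preprocessor: the pipeline from
# make_groot_pre_post_processors (defined below). Both are assumed to exist already.
policy.reset()
for _ in range(num_steps):  # num_steps is illustrative
    batch = preprocessor(raw_observation)  # adds state, state_mask, embodiment_id and eagle_* tensors
    action = policy.select_action(batch)  # pops one action; refills the queue every n_action_steps calls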
664
src/lerobot/policies/groot/processor_groot.py
Normal file
@@ -0,0 +1,664 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, ProcessorMixin
|
||||
else:
|
||||
AutoProcessor = None
|
||||
ProcessorMixin = object
|
||||
|
||||
from lerobot.configs.types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
PolicyFeature,
|
||||
)
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
HF_LEROBOT_HOME,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
# Defaults for Eagle processor locations
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
def make_groot_pre_post_processors(
|
||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create preprocessor and postprocessor for Groot policy.
|
||||
|
||||
This creates a processing pipeline that transforms LeRobot data format into
|
||||
the format expected by Isaac-GR00T models:
|
||||
|
||||
Preprocessing steps:
|
||||
1. Optional key renaming (dataset-specific key mapping)
|
||||
2. Add batch dimension to unbatched data
|
||||
3. Pack video/state/action/language/embodiment and apply optional min-max normalization before padding
|
||||
4. Encode video+language with Eagle VLM into intermediate eagle_content
|
||||
5. Collate eagle_content into batched eagle_* tensors
|
||||
6. Move tensors to device (GPU)
|
||||
|
||||
NOTE: We optionally apply min-max normalization to STATE and ACTION using
|
||||
dataset-provided statistics prior to padding, mapping values to [-1, 1].
|
||||
This mirrors SO100-style preprocessing and keeps scales consistent with GR00T.
|
||||
|
||||
Args:
|
||||
config: Groot configuration containing data_config, embodiment_tag, etc.
|
||||
dataset_stats: Optional per-key min/max statistics for normalization before padding.
|
||||
|
||||
Returns:
|
||||
Tuple of (preprocessor, postprocessor) pipelines
|
||||
"""
|
||||
|
||||
# Get horizon/dimension parameters from config
|
||||
# These should match the config used for the pretrained model
|
||||
# Default values match most GR00T configs (state_horizon=1, action_horizon=16)
|
||||
state_horizon = 1
|
||||
# CRITICAL: Pretrained GR00T models use action_horizon=16 max!
|
||||
# The model architecture hardcodes this limit
|
||||
action_horizon = min(config.chunk_size, 16)
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
|
||||
# Pass raw dataset_stats; normalization will occur inside pack step before padding
|
||||
padded_stats = dataset_stats or {}
|
||||
|
||||
# Define feature specs for optional normalization steps
|
||||
_features: dict[str, PolicyFeature] = {
|
||||
# Observation features (only add those we may normalize)
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
|
||||
# Action feature
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
|
||||
}
|
||||
|
||||
# Normalize STATE and ACTION with min_max (SO100-like default)
|
||||
_norm_map = {
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Determine env action dimension from config (simple, object-like PolicyFeature)
|
||||
try:
|
||||
env_action_dim = int(config.output_features["action"].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
# 1. Rename keys if needed (e.g., dataset-specific camera names)
|
||||
# Leave empty for now - add mappings if your dataset uses different key names
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
# 2. Add batch dimension for single samples
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# 3. Pack video/state/action/language/embodiment; apply optional min-max normalization before padding
|
||||
GrootPackInputsStep(
|
||||
state_horizon=state_horizon,
|
||||
action_horizon=action_horizon,
|
||||
max_state_dim=max_state_dim,
|
||||
max_action_dim=max_action_dim,
|
||||
language_key="task",
|
||||
formalize_language=False,
|
||||
embodiment_tag=config.embodiment_tag,
|
||||
normalize_min_max=True,
|
||||
stats=padded_stats,
|
||||
),
|
||||
# 4. Eagle encode (creates eagle_content)
|
||||
GrootEagleEncodeStep(
|
||||
tokenizer_assets_repo=config.tokenizer_assets_repo,
|
||||
),
|
||||
# 5. Collate eagle_content -> eagle_* tensors
|
||||
GrootEagleCollateStep(
|
||||
tokenizer_assets_repo=config.tokenizer_assets_repo,
|
||||
),
|
||||
# 6. Move to device
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
# Postprocessing: slice to env action dim and unnormalize to env scale, then move to CPU
|
||||
output_steps: list[ProcessorStep] = [
|
||||
GrootActionUnpackUnnormalizeStep(
|
||||
env_action_dim=env_action_dim,
|
||||
stats=padded_stats,
|
||||
normalize_min_max=True,
|
||||
),
|
||||
# Finally, move to CPU for env interaction
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
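# Usage sketch (illustrative): wiring the pipelines above around a Groot policy.
#   preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=dataset.meta.stats)
#   batch = preprocessor(observation)      # packed, Eagle-encoded, moved to config.device
#   chunk = policy.predict_action_chunk(batch)
#   action = postprocessor(chunk)          # sliced to env action dim, unnormalized, moved to CPU
# `dataset.meta.stats` and `policy` are assumptions used only for illustration.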
|
||||
|
||||
|
||||
# GR00T specific processor steps
|
||||
|
||||
|
||||
def _to_uint8_np_bhwc(img_t: torch.Tensor) -> np.ndarray:
|
||||
# img_t: (B, C, H, W) float in [0,1] or uint8
|
||||
if img_t.dtype.is_floating_point:
|
||||
img_t = (img_t.clamp(0, 1) * 255.0).to(torch.uint8)
|
||||
return rearrange(img_t.cpu().numpy(), "b c h w -> b h w c")
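# e.g. a (B, 3, 224, 224) float tensor in [0, 1] becomes a (B, 224, 224, 3) uint8 array.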
|
||||
|
||||
|
||||
def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO) -> ProcessorMixin:
|
||||
# Validate that the cache directory is ready. If not, instruct the user.
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
required = [
|
||||
cache_dir / "processor_config.json",
|
||||
cache_dir / "preprocessor_config.json",
|
||||
cache_dir / "image_processing_eagle2_5_vl_fast.py",
|
||||
]
|
||||
if not all(p.exists() for p in required):
|
||||
raise FileNotFoundError(
|
||||
f"[GROOT] Eagle processor cache at '{cache_dir}' is not populated. "
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_pack_inputs_v3")
|
||||
class GrootPackInputsStep(ProcessorStep):
|
||||
state_horizon: int = 1
|
||||
action_horizon: int = 16
|
||||
max_state_dim: int = 64
|
||||
max_action_dim: int = 32
|
||||
language_key: str = "task"
|
||||
formalize_language: bool = False
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
embodiment_mapping: dict[str, int] = field(
|
||||
default_factory=lambda: {
|
||||
"new_embodiment": 31, # Match original GR00T EMBODIMENT_TAG_MAPPING
|
||||
"oxe_droid": 17,
|
||||
"agibot_genie1": 26,
|
||||
"gr1": 24,
|
||||
"so100": 2,
|
||||
"unitree_g1": 3,
|
||||
}
|
||||
)
|
||||
# Min-max normalization (SO100-like) applied BEFORE padding
|
||||
normalize_min_max: bool = True
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
|
||||
def _align_vec(vec: Any, target_dim: int, *, default: float) -> torch.Tensor:
|
||||
t = torch.as_tensor(vec)
|
||||
t = t.flatten().to(
|
||||
dtype=torch.float32,
|
||||
device=next(
|
||||
(v.device for v in obs.values() if isinstance(v, torch.Tensor)), torch.device("cpu")
|
||||
),
|
||||
)
|
||||
d = int(t.shape[-1]) if t.numel() > 0 else 0
|
||||
if d == target_dim:
|
||||
return t
|
||||
if d < target_dim:
|
||||
pad = torch.full((target_dim - d,), default, dtype=t.dtype, device=t.device)
|
||||
return torch.cat([t, pad], dim=0)
|
||||
return t[:target_dim]
|
||||
|
||||
def _min_max_norm(x: torch.Tensor, key: str) -> torch.Tensor:
|
||||
if not self.normalize_min_max:
|
||||
return x
|
||||
if self.stats is None or key not in self.stats:
|
||||
return x
|
||||
stats_k = self.stats[key]
|
||||
last_dim = x.shape[-1]
|
||||
min_v = _align_vec(stats_k.get("min", torch.zeros(last_dim)), last_dim, default=0.0)
|
||||
max_v = _align_vec(stats_k.get("max", torch.ones(last_dim)), last_dim, default=1.0)
|
||||
denom = max_v - min_v
|
||||
mask = denom != 0
|
||||
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
|
||||
mapped = 2 * (x - min_v) / safe_denom - 1
|
||||
return torch.where(mask, mapped, torch.zeros_like(mapped))
|
||||
|
||||
# 1) Video -> (B, T=1, V, C, H, W) uint8 after the channel reorder below
|
||||
img_keys = sorted([k for k in obs if k.startswith("observation.images.")])
|
||||
if not img_keys and "observation.image" in obs:
|
||||
img_keys = ["observation.image"]
|
||||
if img_keys:
|
||||
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
|
||||
video = np.stack(cams, axis=1) # (B, V, H, W, C)
|
||||
video = np.expand_dims(video, axis=1) # (B, 1, V, H, W, C)
|
||||
# GR00T validates that video.shape[3] == 3 (channels), so reorder to (B, T, V, C, H, W)
|
||||
video = np.transpose(video, (0, 1, 2, 5, 3, 4)) # (B, 1, V, C, H, W)
|
||||
obs["video"] = video
|
||||
# Drop raw images to avoid confusion downstream
|
||||
for k in img_keys:
|
||||
obs.pop(k, None)
|
||||
|
||||
# 2) Language (string)
|
||||
lang = comp.get(self.language_key)
|
||||
if isinstance(lang, list):
|
||||
lang = lang[0] if len(lang) > 0 else None
|
||||
if not lang:
|
||||
lang = "Perform the task."
|
||||
if self.formalize_language:
|
||||
lang = (lang or "").lower()
|
||||
lang = "".join(ch for ch in lang if ch.isalnum() or ch.isspace())
|
||||
comp["language"] = lang
|
||||
|
||||
# 3) State/state_mask -> (B, 1, max_state_dim)
|
||||
if "observation.state" in obs:
|
||||
state = obs["observation.state"] # (B, D)
|
||||
if state.dim() != 2:
|
||||
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
|
||||
bsz, d = state.shape
|
||||
# Normalize BEFORE padding
|
||||
if self.normalize_min_max:
|
||||
state = _min_max_norm(state, "observation.state")
|
||||
state = state.unsqueeze(1) # (B, 1, D)
|
||||
if d > self.max_state_dim:
|
||||
state = state[:, :, : self.max_state_dim]
|
||||
d = self.max_state_dim
|
||||
elif d < self.max_state_dim:
|
||||
pad = torch.zeros(bsz, 1, self.max_state_dim - d, dtype=state.dtype, device=state.device)
|
||||
state = torch.cat([state, pad], dim=2)
|
||||
state_mask = torch.zeros(bsz, 1, self.max_state_dim, dtype=torch.bool, device=state.device)
|
||||
state_mask[:, :, :d] = True
|
||||
obs["state"] = state
|
||||
obs["state_mask"] = state_mask
|
||||
|
||||
# 4) Action/action_mask -> (B, action_horizon, max_action_dim)
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if isinstance(action, torch.Tensor):
|
||||
# Normalize BEFORE temporal expansion/padding
|
||||
if self.normalize_min_max:
|
||||
if action.dim() == 2:
|
||||
action = _min_max_norm(action, "action")
|
||||
elif action.dim() == 3:
|
||||
b, t, d = action.shape
|
||||
flat = action.reshape(b * t, d)
|
||||
flat = _min_max_norm(flat, "action")
|
||||
action = flat.view(b, t, d)
|
||||
if action.dim() == 2:
|
||||
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
|
||||
elif action.dim() == 3:
|
||||
b, t, d = action.shape
|
||||
if t < self.action_horizon:
|
||||
last = action[:, -1:, :]
|
||||
pad = last.repeat(1, self.action_horizon - t, 1)
|
||||
action = torch.cat([action, pad], dim=1)
|
||||
elif t > self.action_horizon:
|
||||
action = action[:, : self.action_horizon, :]
|
||||
else:
|
||||
raise ValueError(f"action must be (B, D) or (B, T, D), got {tuple(action.shape)}")
|
||||
|
||||
b, t, d = action.shape
|
||||
if d > self.max_action_dim:
|
||||
action = action[:, :, : self.max_action_dim]
|
||||
d = self.max_action_dim
|
||||
elif d < self.max_action_dim:
|
||||
pad = torch.zeros(b, t, self.max_action_dim - d, dtype=action.dtype, device=action.device)
|
||||
action = torch.cat([action, pad], dim=2)
|
||||
action_mask = torch.zeros(b, t, self.max_action_dim, dtype=torch.bool, device=action.device)
|
||||
action_mask[:, :, :d] = True
|
||||
transition[TransitionKey.ACTION] = action
|
||||
comp["action_mask"] = action_mask
|
||||
|
||||
# 5) Embodiment id as LongTensor (B,)
|
||||
emb_id = self.embodiment_mapping.get(self.embodiment_tag, 0)
|
||||
# Infer batch size/device from any tensor in obs or action
|
||||
bsz = None
|
||||
device = torch.device("cpu")
|
||||
for v in list(obs.values()) + [transition.get(TransitionKey.ACTION)]:
|
||||
if isinstance(v, torch.Tensor):
|
||||
bsz = v.shape[0]
|
||||
device = v.device
|
||||
break
|
||||
if bsz is None and "video" in obs and isinstance(obs["video"], np.ndarray):
|
||||
bsz = obs["video"].shape[0]
|
||||
if bsz is None:
|
||||
bsz = 1
|
||||
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.long, device=device)
|
||||
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return transition
|
||||
|
||||
# Pipeline API requirement: declare how features change (we keep it simple)
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns a serializable dictionary of the processor's configuration.
|
||||
|
||||
Excludes 'stats' since they are saved separately via state_dict().
|
||||
"""
|
||||
return {
|
||||
"state_horizon": self.state_horizon,
|
||||
"action_horizon": self.action_horizon,
|
||||
"max_state_dim": self.max_state_dim,
|
||||
"max_action_dim": self.max_action_dim,
|
||||
"language_key": self.language_key,
|
||||
"formalize_language": self.formalize_language,
|
||||
"embodiment_tag": self.embodiment_tag,
|
||||
"embodiment_mapping": self.embodiment_mapping,
|
||||
"normalize_min_max": self.normalize_min_max,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns normalization statistics as a flat state dictionary.
|
||||
|
||||
This enables saving stats to safetensors files, similar to normalizer_processor.
|
||||
"""
|
||||
if not self.stats:
|
||||
return {}
|
||||
|
||||
flat: dict[str, torch.Tensor] = {}
|
||||
for key, sub in self.stats.items():
|
||||
for stat_name, value in sub.items():
|
||||
tensor = torch.as_tensor(value).cpu()
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Loads normalization statistics from a flat state dictionary.
|
||||
|
||||
This enables loading stats from safetensors files during from_pretrained.
|
||||
"""
|
||||
if not state:
|
||||
return
|
||||
|
||||
reconstructed: dict[str, dict[str, Any]] = {}
|
||||
for flat_key, tensor in state.items():
|
||||
if "." in flat_key:
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
if key not in reconstructed:
|
||||
reconstructed[key] = {}
|
||||
reconstructed[key][stat_name] = tensor
|
||||
|
||||
if reconstructed:
|
||||
self.stats = reconstructed
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_eagle_encode_v3")
|
||||
class GrootEagleEncodeStep(ProcessorStep):
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def proc(self) -> ProcessorMixin:
|
||||
if self._proc is None:
|
||||
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
|
||||
return self._proc
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
|
||||
if "video" not in obs:
|
||||
return transition
|
||||
|
||||
video = obs["video"]  # (B, T, V, C, H, W) uint8, as produced by GrootPackInputsStep
|
||||
lang = comp.get("language", "Perform the task.")
|
||||
if isinstance(lang, list):
|
||||
lang = lang[0] if len(lang) > 0 else "Perform the task."
|
||||
|
||||
bsz = video.shape[0]
|
||||
eagle_contents: list[dict[str, Any]] = []
|
||||
for b in range(bsz):
|
||||
vt = video[b] # (T, V, C, H, W) after reorder
|
||||
if vt.ndim != 5:
|
||||
# Fallback: assume (T, V, H, W, C)
|
||||
t, v, h, w, c = vt.shape
|
||||
flat = rearrange(vt, "t v h w c -> (t v) h w c")
|
||||
else:
|
||||
t, v, c, h, w = vt.shape
|
||||
flat = rearrange(vt, "t v c h w -> (t v) h w c")
|
||||
images = [Image.fromarray(flat[i]) for i in range(t * v)]
|
||||
# Format language as a string-list representation to match the original GR00T
|
||||
lang_formatted = str([lang])
|
||||
text_content = [{"type": "text", "text": lang_formatted}]
|
||||
image_content = [{"type": "image", "image": img} for img in images]
|
||||
conv = [{"role": "user", "content": image_content + text_content}]
|
||||
text_list = [self.proc.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)]
|
||||
img_inputs, vid_inputs = self.proc.process_vision_info(conv)
|
||||
eagle_contents.append(
|
||||
{
|
||||
"text_list": text_list,
|
||||
"image_inputs": img_inputs,
|
||||
"video_inputs": vid_inputs,
|
||||
}
|
||||
)
|
||||
|
||||
comp["eagle_content"] = eagle_contents
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return transition
|
||||
|
||||
# Pipeline API requirement: declare how features change (no schema change here)
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
# Original GR00T-style collate: converts eagle_content -> eagle_* tensors
|
||||
def collate(features: list[dict[str, Any]], eagle_processor: ProcessorMixin) -> dict[str, Any]:
|
||||
batch: dict[str, Any] = {}
|
||||
keys = features[0].keys()
|
||||
|
||||
for key in keys:
|
||||
values = [elem[key] for elem in features]
|
||||
|
||||
if key == "eagle_content":
|
||||
text_list: list[str] = []
|
||||
image_inputs: list[Any] = []
|
||||
for v in values:
|
||||
curr_text_list = v["text_list"]
|
||||
curr_image_inputs = v["image_inputs"]
|
||||
text_list += curr_text_list
|
||||
image_inputs += curr_image_inputs
|
||||
eagle_inputs = eagle_processor(
|
||||
text=text_list,
|
||||
images=image_inputs,
|
||||
images_kwargs={"min_dynamic_tiles": 1, "max_dynamic_tiles": 1, "use_thumbnail": False},
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
for k, v in eagle_inputs.items():
|
||||
k = "eagle_" + k
|
||||
batch[k] = v
|
||||
elif key in ("pixel_values", "image_grid_thw", "attention_mask", "input_ids"):
|
||||
# Concat in existing batch dimension.
|
||||
batch[key] = torch.cat(values)
|
||||
else:
|
||||
# state, state_mask, action and action_mask.
|
||||
# Stack to form the batch dimension.
|
||||
batch[key] = torch.from_numpy(np.stack(values))
|
||||
return batch
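# Usage sketch (illustrative): batching per-sample eagle_content dicts, as GrootEagleCollateStep does below.
#   features = [{"eagle_content": content} for content in eagle_contents]
#   batch = collate(features, eagle_processor)
#   # batch then holds eagle_input_ids, eagle_attention_mask, eagle_pixel_values, eagle_image_sizes, ...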
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_eagle_collate_v3")
|
||||
class GrootEagleCollateStep(ProcessorStep):
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def proc(self) -> ProcessorMixin:
|
||||
if self._proc is None:
|
||||
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
|
||||
return self._proc
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
contents = comp.get("eagle_content")
|
||||
if not contents:
|
||||
return transition
|
||||
|
||||
# Build features list as original API expects: one dict per batch item
|
||||
features = [{"eagle_content": content} for content in contents]
|
||||
batched = collate(features, self.proc)
|
||||
|
||||
# Inject eagle_* tensors and remove the temporary content and raw video to free memory
|
||||
for k, v in batched.items():
|
||||
comp[k] = v
|
||||
comp.pop("eagle_content", None)
|
||||
obs.pop(
|
||||
"video", None
|
||||
) # The video has been fully encoded into eagle_* tensors, so we don't need the raw video anymore
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
    env_action_dim: int = 0
    # Apply inverse of min-max normalization if it was used in preprocessor
    normalize_min_max: bool = True
    stats: dict[str, dict[str, Any]] | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
        action = transition.get(TransitionKey.ACTION)
        if not isinstance(action, torch.Tensor):
            return transition

        # Select last timestep and slice to env dimension
        if action.dim() == 3:
            action = action[:, -1, :]
        # Now action is (B, D_model)
        if self.env_action_dim and action.shape[-1] >= self.env_action_dim:
            action = action[..., : self.env_action_dim]

        # Inverse min-max normalization mirroring _min_max_norm:
        # forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0
        # inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min
        if self.normalize_min_max and self.stats is not None:
            stats_k = self.stats.get("action", {})
            d = action.shape[-1]
            min_v = torch.as_tensor(
                stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device
            )
            max_v = torch.as_tensor(
                stats_k.get("max", torch.ones(d)), dtype=action.dtype, device=action.device
            )
            if min_v.numel() != d:
                min_v = torch.nn.functional.pad(min_v.flatten()[:d], (0, max(0, d - min_v.numel())))
                min_v = min_v.to(action.device, dtype=action.dtype)
            if max_v.numel() != d:
                max_v = torch.nn.functional.pad(max_v.flatten()[:d], (0, max(0, d - max_v.numel())))
                max_v = max_v.to(action.device, dtype=action.dtype)
            denom = max_v - min_v
            mask = denom != 0
            safe_denom = torch.where(mask, denom, torch.ones_like(denom))
            inv = (action + 1.0) * 0.5 * safe_denom + min_v
            action = torch.where(mask, inv, min_v)

        transition[TransitionKey.ACTION] = action
        return transition

    def transform_features(self, features):
        return features

    def get_config(self) -> dict[str, Any]:
        """
        Returns a serializable dictionary of the processor's configuration.

        Excludes 'stats' since they are saved separately via state_dict().
        """
        return {
            "env_action_dim": self.env_action_dim,
            "normalize_min_max": self.normalize_min_max,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """
        Returns normalization statistics as a flat state dictionary.

        This enables saving stats to safetensors files, similar to normalizer_processor.
        """
        if not self.stats:
            return {}

        flat: dict[str, torch.Tensor] = {}
        for key, sub in self.stats.items():
            for stat_name, value in sub.items():
                tensor = torch.as_tensor(value).cpu()
                flat[f"{key}.{stat_name}"] = tensor
        return flat

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """
        Loads normalization statistics from a flat state dictionary.

        This enables loading stats from safetensors files during from_pretrained.
        """
        if not state:
            return

        reconstructed: dict[str, dict[str, Any]] = {}
        for flat_key, tensor in state.items():
            if "." in flat_key:
                key, stat_name = flat_key.rsplit(".", 1)
                if key not in reconstructed:
                    reconstructed[key] = {}
                reconstructed[key][stat_name] = tensor

        if reconstructed:
            self.stats = reconstructed
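As a quick sanity check of the inverse min-max mapping used in __call__ above, a minimal numeric sketch (the values are made up; the degenerate third dimension shows the denom == 0 fallback):

import torch

min_v = torch.tensor([0.0, -1.0, 2.0])
max_v = torch.tensor([1.0, 1.0, 2.0])   # last dim has max == min -> denom == 0
y = torch.tensor([[-1.0, 0.0, 0.7]])    # normalized action in [-1, 1]

denom = max_v - min_v
mask = denom != 0
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
x = torch.where(mask, (y + 1.0) * 0.5 * safe_denom + min_v, min_v)
# x == tensor([[0.0, 0.0, 2.0]]): degenerate dims collapse to their stored min.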
47
src/lerobot/policies/groot/utils.py
Normal file
@@ -0,0 +1,47 @@
from pathlib import Path
from shutil import copytree

from huggingface_hub import hf_hub_download


def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
    """Populate the Eagle processor directory in cache and ensure tokenizer assets exist.

    - Copies the vendored Eagle files into cache_dir (overwriting when needed).
    - Downloads any missing tokenizer/processor assets (vocab.json, merges.txt, config files, ...) into the same cache_dir.
    """
    cache_dir = Path(cache_dir)
    vendor_dir = Path(vendor_dir)

    try:
        # Populate/refresh cache with vendor files to ensure a complete processor directory
        print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
        copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
    except Exception as exc:  # nosec: B110
        print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")

    required_assets = [
        "vocab.json",
        "merges.txt",
        "added_tokens.json",
        "chat_template.json",
        "special_tokens_map.json",
        "config.json",
        "generation_config.json",
        "preprocessor_config.json",
        "processor_config.json",
        "tokenizer_config.json",
    ]

    print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")

    for fname in required_assets:
        dst = cache_dir / fname
        if not dst.exists():
            print(f"[GROOT] Fetching {fname}")
            hf_hub_download(
                repo_id=assets_repo,
                filename=fname,
                repo_type="model",
                local_dir=str(cache_dir),
            )
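For reference, a hypothetical invocation of this helper; the paths and the repo id below are placeholders, not values taken from the diff:

# Hypothetical call site (placeholder paths/repo id; the real ones live in the groot processor module).
from pathlib import Path

ensure_eagle_cache_ready(
    vendor_dir=Path("src/lerobot/policies/groot/eagle_files"),     # assumed location of the vendored files
    cache_dir=Path.home() / ".cache" / "lerobot" / "groot_eagle",  # assumed cache location
    assets_repo="<tokenizer-assets-repo-id>",                      # placeholder Hugging Face repo id
)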
@@ -505,14 +505,17 @@ def eval_main(cfg: EvalPipelineConfig):
    )

    policy.eval()

    # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
    preprocessor_overrides = {
        "device_processor": {"device": str(policy.config.device)},
        "rename_observations_processor": {"rename_map": cfg.rename_map},
    }

    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
        # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
        preprocessor_overrides={
            "device_processor": {"device": str(policy.config.device)},
            "rename_observations_processor": {"rename_map": cfg.rename_map},
        },
        preprocessor_overrides=preprocessor_overrides,
    )
    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        info = eval_policy_all(
@@ -62,6 +62,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b


_transformers_available = is_package_available("transformers")
_peft_available = is_package_available("peft")


def make_device_from_device_class(config: ChoiceRegistry) -> Any:
207
tests/policies/groot/test_groot_lerobot.py
Normal file
@@ -0,0 +1,207 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test script for LeRobot's Groot policy forward and inference passes."""

import gc
import os
from copy import deepcopy
from typing import Any

import numpy as np
import pytest
import torch

from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.utils import auto_select_torch_device
from tests.utils import require_cuda  # noqa: E402

pytest.importorskip("transformers")

pytestmark = pytest.mark.skipif(
    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
    reason="This test requires local Groot installation and is not meant for CI",
)


# Define constants for dummy data
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = auto_select_torch_device()
MODEL_PATH = "aractingi/bimanual-handover-groot-10k"


def cleanup_memory():
    """Clean up GPU/MPS memory to prevent OOM errors between tests."""
    print("\nCleaning up memory...")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    print("Memory cleanup complete.")


def set_seed_all(seed: int):
    """Set random seed for all RNG sources to ensure reproducibility."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Set deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)


def instantiate_lerobot_groot(
    from_pretrained: bool = False,
    model_path: str = MODEL_PATH,
) -> tuple[
    GrootPolicy,
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
    if from_pretrained:
        policy = GrootPolicy.from_pretrained(
            pretrained_name_or_path=model_path,
            strict=False,
        )
        policy.config.embodiment_tag = "gr1"
    else:
        config = GrootConfig(
            base_model_path=model_path,
            n_action_steps=DUMMY_ACTION_HORIZON,
            chunk_size=DUMMY_ACTION_HORIZON,
            image_size=[IMAGE_SIZE, IMAGE_SIZE],
            device=DEVICE,
            embodiment_tag="gr1",
        )
        policy = GrootPolicy(config)

    policy.to(DEVICE)
    policy.config.device = DEVICE

    preprocessor, postprocessor = make_groot_pre_post_processors(
        config=policy.config,
        dataset_stats=None,  # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
    )

    return (policy, preprocessor, postprocessor)


def create_dummy_data(device=DEVICE):
    """Create a dummy data batch for testing."""
    batch_size = 2
    prompt = "Pick up the red cube and place it in the bin"
    state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)

    batch = {
        "observation.state": state,
        "action": torch.randn(
            batch_size,
            DUMMY_ACTION_HORIZON,
            DUMMY_ACTION_DIM,
            dtype=torch.float32,
            device=device,  # Action ground truth (for training)
        ),
        "observation.images.ego_view": torch.rand(
            batch_size,
            3,
            IMAGE_SIZE,
            IMAGE_SIZE,
            dtype=torch.float32,
            device=device,  # Images in [0, 1] range as expected by LeRobot
        ),
        "task": [prompt for _ in range(batch_size)],
    }

    return batch


@require_cuda
def test_lerobot_groot_inference():
    """Test the inference pass (select_action) of LeRobot's Groot policy."""
    print("Test: LeRobot Groot Inference Pass")

    set_seed_all(42)

    # Instantiate policy and processors
    lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
        from_pretrained=True
    )
    batch = create_dummy_data()

    print("\n[LeRobot] Running inference...")
    lerobot_policy.eval()
    batch_lerobot_processed = lerobot_preprocessor(deepcopy(batch))

    # Ensure identical RNG state before inference
    torch.manual_seed(42)

    with torch.no_grad():
        lerobot_action = lerobot_policy.select_action(batch_lerobot_processed)

    print(f"\nInference successful. Output action shape: {lerobot_action.shape}")
    print("Output actions (first 5 dims):")
    print(lerobot_action[:, :5])

    lerobot_action = lerobot_postprocessor(lerobot_action)

    del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor, batch
    cleanup_memory()


@require_cuda
def test_lerobot_groot_forward_pass():
    """Test the forward pass of LeRobot's Groot policy."""
    print("\n" + "=" * 50)
    print("Test: LeRobot Groot Forward Pass (Training Mode)")

    set_seed_all(42)

    # Instantiate policy and processors
    lerobot_policy, lerobot_preprocessor, _ = instantiate_lerobot_groot(from_pretrained=True)
    batch = create_dummy_data()

    lerobot_policy.eval()

    print("\n[LeRobot] Running forward pass...")
    batch_lerobot_processed = lerobot_preprocessor(deepcopy(batch))

    set_seed_all(42)
    with torch.no_grad():
        lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)

    print("\nForward pass successful.")
    print(f" - Loss: {lerobot_loss.item():.6f}")
    print(f" - Metrics: {lerobot_metrics}")

    del lerobot_policy, lerobot_preprocessor, batch
    cleanup_memory()
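These tests are skipped automatically when CI or GITHUB_ACTIONS is set, so they are meant to be launched by hand. A minimal local runner, assuming the groot extra and a CUDA device are available:

# Minimal local runner for the Groot tests added above (assumes the groot extra and a CUDA device).
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-vv", "tests/policies/groot/test_groot_lerobot.py"]))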
443
tests/policies/groot/test_groot_vs_original.py
Normal file
@@ -0,0 +1,443 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test script to verify Groot policy integration with LeRobot vs the original implementation; only meant to be run locally!"""

import gc
import os
from copy import deepcopy
from typing import Any

import numpy as np
import pytest
import torch

from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
from lerobot.processor import PolicyAction, PolicyProcessorPipeline

pytest.importorskip("gr00t")
pytest.importorskip("transformers")

pytestmark = pytest.mark.skipif(
    os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
    reason="This test requires local Groot installation and is not meant for CI",
)


from gr00t.data.dataset import ModalityConfig  # noqa: E402
from gr00t.data.embodiment_tags import EmbodimentTag  # noqa: E402
from gr00t.data.transform.base import ComposedModalityTransform  # noqa: E402
from gr00t.model.policy import Gr00tPolicy  # noqa: E402

# GR1 humanoid dimensions (from pretrained model metadata)
# The actual GR1 robot has 44 dimensions for both state and action
# GR00TTransform will pad state to 64 and truncate action to 32
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = "cpu"
MODEL_PATH = "nvidia/GR00T-N1.5-3B"

GR1_BODY_PARTS = {
    "left_arm": 7,
    "left_hand": 6,
    "left_leg": 6,
    "neck": 3,
    "right_arm": 7,
    "right_hand": 6,
    "right_leg": 6,
    "waist": 3,
}


def cleanup_memory():
    """Clean up GPU/MPS memory to prevent OOM errors between tests."""
    print("\nCleaning up memory...")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    print("Memory cleanup complete.")


def set_seed_all(seed: int):
    """Set random seed for all RNG sources to ensure reproducibility."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Set deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)


def instantiate_lerobot_groot(
    from_pretrained: bool = False,
    model_path: str = MODEL_PATH,
) -> tuple[
    GrootPolicy,
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
    if from_pretrained:
        policy = GrootPolicy.from_pretrained(
            pretrained_name_or_path=model_path,
            strict=False,
        )
        policy.config.embodiment_tag = "gr1"
    else:
        config = GrootConfig(
            base_model_path=model_path,
            n_action_steps=DUMMY_ACTION_HORIZON,
            chunk_size=DUMMY_ACTION_HORIZON,
            image_size=[IMAGE_SIZE, IMAGE_SIZE],
            device=DEVICE,
            embodiment_tag="gr1",
        )
        policy = GrootPolicy(config)

    policy.to(DEVICE)
    policy.config.device = DEVICE

    preprocessor, postprocessor = make_groot_pre_post_processors(
        config=policy.config,
        dataset_stats=None,  # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
    )

    return (policy, preprocessor, postprocessor)


def instantiate_original_groot(
    from_pretrained: bool = False,
    model_path: str = MODEL_PATH,
):
    """Instantiate original Groot policy from NVIDIA's implementation."""
    from gr00t.data.transform.concat import ConcatTransform
    from gr00t.data.transform.state_action import StateActionToTensor
    from gr00t.data.transform.video import VideoToNumpy, VideoToTensor
    from gr00t.model.transforms import GR00TTransform

    video_keys = ["video.ego_view"]
    # Important: Use a single concatenated "state" key (not split body parts) to match preprocessing
    state_keys = ["state"]
    action_keys = [
        "action.left_arm",
        "action.right_arm",
        "action.left_hand",
        "action.right_hand",
        "action.left_leg",
        "action.right_leg",
        "action.neck",
        "action.waist",
    ]
    language_keys = ["annotation.human.action.task_description"]

    modality_config = {
        "video": ModalityConfig(
            delta_indices=[0],  # Current frame only
            modality_keys=video_keys,
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=state_keys,
        ),
        "action": ModalityConfig(
            delta_indices=list(range(DUMMY_ACTION_HORIZON)),
            modality_keys=action_keys,
        ),
        "language": ModalityConfig(
            delta_indices=[0],
            modality_keys=language_keys,
        ),
    }

    modality_transform = ComposedModalityTransform(
        transforms=[
            VideoToTensor(apply_to=video_keys),
            VideoToNumpy(apply_to=video_keys),  # Convert to numpy (GR00TTransform expects numpy arrays)
            # State is already a single concatenated key, so no StateActionToTensor needed
            # Convert action from numpy to tensor
            StateActionToTensor(apply_to=action_keys),
            # Concatenate only video and actions (state is already a single key)
            ConcatTransform(
                video_concat_order=video_keys,
                state_concat_order=[],  # Empty: state is already a single key
                action_concat_order=action_keys,
            ),
            GR00TTransform(
                max_state_dim=64,
                max_action_dim=32,
                state_horizon=1,
                action_horizon=DUMMY_ACTION_HORIZON,
                training=False,
            ),
        ]
    )

    policy = Gr00tPolicy(
        model_path=model_path,
        embodiment_tag=EmbodimentTag.GR1,
        modality_config=modality_config,
        modality_transform=modality_transform,
        device=DEVICE,
    )

    return policy, modality_config, modality_transform


def create_dummy_data(device=DEVICE):
    """Create dummy data for testing both implementations."""
    batch_size = 2
    prompt = "Pick up the red cube and place it in the bin"
    state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)

    batch = {
        "observation.state": state,
        "action": torch.randn(
            batch_size,
            DUMMY_ACTION_HORIZON,
            DUMMY_ACTION_DIM,
            dtype=torch.float32,
            device=device,  # Action ground truth (for training)
        ),
        "observation.images.ego_view": torch.rand(
            batch_size,
            3,
            IMAGE_SIZE,
            IMAGE_SIZE,
            dtype=torch.float32,
            device=device,  # Images in [0, 1] range as expected by LeRobot
        ),
        "task": [prompt for _ in range(batch_size)],
    }

    return batch


def convert_lerobot_to_original_format(batch, modality_config):
    """Convert LeRobot batch format to original Groot format.

    The original Groot expects observations in this format:
    {
        "video.<camera_name>": np.ndarray (T, H, W, C) or (B, T, H, W, C)
        "state.<state_component>": np.ndarray (T, D) or (B, T, D)
        "action.<action_component>": np.ndarray (T, D) or (B, T, D)
        "annotation.<annotation_type>": str or list[str]
    }
    """
    # Original Groot expects (T, H, W, C) format for images
    # LeRobot has (B, C, H, W) format, so we need to convert
    observation = {}

    for img_key in ["ego_view"]:
        lerobot_key = f"observation.images.{img_key}"
        if lerobot_key in batch:
            img = batch[lerobot_key]
            # Convert from (B, C, H, W) to (B, T=1, H, W, C)
            img_np = img.permute(0, 2, 3, 1).unsqueeze(1).cpu().numpy()
            # Convert [0, 1] to [0, 255] uint8 as expected by original
            img_np = (img_np * 255).astype(np.uint8)
            observation[f"video.{img_key}"] = img_np

    # Important: The original's GR00TTransform expects "state" as (B, T, D), not split body parts
    if "observation.state" in batch:
        state = batch["observation.state"]
        state_np = state.unsqueeze(1).cpu().numpy()  # (B, 1, D)
        observation["state"] = state_np

    if "action" in batch:
        action = batch["action"]
        action_np = action.cpu().numpy()

        start_idx = 0
        for part_name, part_dim in GR1_BODY_PARTS.items():
            end_idx = start_idx + part_dim
            observation[f"action.{part_name}"] = action_np[:, :, start_idx:end_idx]
            start_idx = end_idx

    if "task" in batch:
        task_list = batch["task"]
        # GR00TTransform expects language with (B, T) shape for batched data
        # Create a (B, T=1) array where each element is the string directly
        bsz = len(task_list)
        task_array = np.empty((bsz, 1), dtype=object)
        for i in range(bsz):
            task_array[i, 0] = task_list[i]  # Assign string directly to each (i, 0) position
        observation["annotation.human.action.task_description"] = task_array

    return observation
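To make the layout conversion above concrete, a stand-alone sketch of the image and state reshaping with the dummy dimensions used in this file (B=2, T=1):

import numpy as np
import torch

# LeRobot image: (B, C, H, W) float in [0, 1] -> original Groot video: (B, T=1, H, W, C) uint8
img = torch.rand(2, 3, 256, 256)
img_np = (img.permute(0, 2, 3, 1).unsqueeze(1).numpy() * 255).astype(np.uint8)

# LeRobot state: (B, D) -> original Groot state: (B, T=1, D)
state = torch.randn(2, 44)
state_np = state.unsqueeze(1).numpy()

print(img_np.shape, img_np.dtype, state_np.shape)  # (2, 1, 256, 256, 3) uint8 (2, 1, 44)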
def test_groot_original_vs_lerobot_pretrained():
    """Test Groot original implementation vs LeRobot implementation with pretrained weights."""
    print("Test: Groot Original vs LeRobot with Pretrained Weights (Inference)")

    set_seed_all(42)

    lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
        from_pretrained=True
    )
    original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)

    batch = create_dummy_data()
    batch_lerobot = deepcopy(batch)

    print("\n[LeRobot] Running inference...")
    lerobot_policy.eval()
    batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)

    # Important: Reset seed immediately before inference to ensure identical RNG state
    torch.manual_seed(42)

    with torch.no_grad():
        lerobot_actions = lerobot_policy.select_action(batch_lerobot_processed)

    print("\n[Original] Running inference...")
    original_policy.model.eval()
    observation = convert_lerobot_to_original_format(batch, modality_config)
    original_obs_transformed = modality_transform(deepcopy(observation))

    # Important: Reset seed immediately before inference to ensure identical RNG state
    torch.manual_seed(42)

    with torch.no_grad():
        original_model_output = original_policy.model.get_action(original_obs_transformed)
        original_actions_raw = original_model_output["action_pred"]  # [2, 16, 32]
        # Take first timestep
        original_actions = original_actions_raw[:, 0, :].to(lerobot_actions.device).to(lerobot_actions.dtype)

    print("Action Comparison:")
    diff = lerobot_actions - original_actions
    abs_diff = torch.abs(diff)

    for batch_idx in range(lerobot_actions.shape[0]):
        print(f"\n{'=' * 60}")
        print(f"Batch {batch_idx}")
        print(f"{'=' * 60}")
        print(f"{'Idx':<5} {'LeRobot':<14} {'Original':<14} {'Difference':<14}")
        print("-" * 60)
        for action_idx in range(lerobot_actions.shape[1]):
            lr_val = lerobot_actions[batch_idx, action_idx].item()
            orig_val = original_actions[batch_idx, action_idx].item()
            diff_val = abs(lr_val - orig_val)
            sign = "+" if (lr_val - orig_val) > 0 else "-"
            print(f"{action_idx:<5} {lr_val:>13.6f} {orig_val:>13.6f} {sign}{diff_val:>12.6f}")

    max_diff = abs_diff.max().item()
    tolerance = 0.001
    assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
        f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6f}"
    )
    print(f"\nSuccess: Actions match within tolerance ({tolerance})!")

    del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
    del original_policy, modality_config, modality_transform
    del batch, batch_lerobot, observation
    cleanup_memory()


def test_groot_forward_pass_comparison():
    """Test forward pass comparison between LeRobot and Original Groot implementations."""
    print("Test: Forward Pass Comparison (Training Mode)")

    set_seed_all(42)

    lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
        from_pretrained=True
    )
    original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)

    batch = create_dummy_data()
    lerobot_policy.eval()
    original_policy.model.eval()

    print("\n[LeRobot] Running forward pass...")
    batch_lerobot = deepcopy(batch)
    batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)

    set_seed_all(42)
    with torch.no_grad():
        lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)

    print(f" Loss: {lerobot_loss.item():.6f}")

    print("\n[Original] Running forward pass...")
    observation = convert_lerobot_to_original_format(batch, modality_config)
    transformed_obs = modality_transform(observation)

    if "action" not in transformed_obs:
        action_for_forward = batch_lerobot_processed["action"]
        action_mask_for_forward = batch_lerobot_processed["action_mask"]

        # Match action horizon if needed
        if action_for_forward.shape[1] != original_policy.model.action_horizon:
            if action_for_forward.shape[1] < original_policy.model.action_horizon:
                pad_size = original_policy.model.action_horizon - action_for_forward.shape[1]
                last_action = action_for_forward[:, -1:, :]
                padding = last_action.repeat(1, pad_size, 1)
                action_for_forward = torch.cat([action_for_forward, padding], dim=1)

                mask_padding = torch.zeros(
                    action_mask_for_forward.shape[0],
                    pad_size,
                    action_mask_for_forward.shape[2],
                    dtype=action_mask_for_forward.dtype,
                    device=action_mask_for_forward.device,
                )
                action_mask_for_forward = torch.cat([action_mask_for_forward, mask_padding], dim=1)
            else:
                action_for_forward = action_for_forward[:, : original_policy.model.action_horizon, :]
                action_mask_for_forward = action_mask_for_forward[
                    :, : original_policy.model.action_horizon, :
                ]

        transformed_obs["action"] = action_for_forward
        transformed_obs["action_mask"] = action_mask_for_forward

    set_seed_all(42)
    with torch.no_grad():
        original_outputs = original_policy.model.forward(transformed_obs)

    original_loss = original_outputs["loss"]
    print(f" Loss: {original_loss.item():.6f}")

    loss_diff = abs(lerobot_loss.item() - original_loss.item())
    loss_rel_diff = loss_diff / (abs(original_loss.item()) + 1e-8) * 100

    print("\nLoss Values:")
    print(f" LeRobot: {lerobot_loss.item():.6f}")
    print(f" Original: {original_loss.item():.6f}")
    print(f" Absolute difference: {loss_diff:.6f}")
    print(f" Relative difference: {loss_rel_diff:.2f}%")

    del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
    del original_policy, modality_config, modality_transform
    del batch, batch_lerobot, observation, transformed_obs
    cleanup_memory()