Compare commits

..

1 Commits

Author SHA1 Message Date
fracapuano
241e7076f2 add: async inference stack 2025-06-03 18:03:42 +02:00
25 changed files with 1570 additions and 1655 deletions

View File

@@ -40,24 +40,24 @@ jobs:
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
uses: docker/setup-buildx-action@v3
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push CPU
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/lerobot-cpu/Dockerfile
@@ -78,24 +78,24 @@ jobs:
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
uses: docker/setup-buildx-action@v3
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/lerobot-gpu/Dockerfile
@@ -110,23 +110,23 @@ jobs:
group: aws-general-8-plus
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
uses: docker/setup-buildx-action@v3
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Login to DockerHub
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU dev
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/lerobot-gpu-dev/Dockerfile

View File

@@ -33,7 +33,7 @@ jobs:
runs-on:
group: aws-general-8-plus
container:
image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
image: huggingface/lerobot-cpu:latest
options: --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -60,7 +60,7 @@ jobs:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu"
container:
image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
image: huggingface/lerobot-gpu:latest
options: --gpus all --shm-size "16gb"
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}

View File

@@ -33,12 +33,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
uses: actions/setup-python@v4
with:
python-version: ${{ env.PYTHON_VERSION }}
@@ -64,9 +64,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
persist-credentials: false
- name: typos-action
uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10
uses: crate-ci/typos@v1.29.10

View File

@@ -35,7 +35,7 @@ jobs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Check out code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
persist-credentials: false
@@ -64,17 +64,17 @@ jobs:
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
uses: docker/setup-buildx-action@v3
with:
cache-binary: false
- name: Check out code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Build Docker image
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
uses: docker/build-push-action@v5
with:
file: ${{ matrix.docker-file }}
context: .

View File

@@ -50,7 +50,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v4
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -62,7 +62,7 @@ jobs:
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
@@ -85,7 +85,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v4
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -94,7 +94,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y ffmpeg
- name: Install uv and python
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
version: ${{ env.UV_VERSION }}
@@ -117,7 +117,7 @@ jobs:
env:
MUJOCO_GL: egl
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@v4
with:
lfs: true # Ensure LFS files are pulled
persist-credentials: false
@@ -129,7 +129,7 @@ jobs:
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
- name: Install uv and python
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
version: ${{ env.UV_VERSION }}

View File

@@ -24,12 +24,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Secret Scanning
uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
uses: trufflesecurity/trufflehog@main
with:
extra_args: --only-verified

View File

@@ -37,18 +37,18 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.32.0
rev: v1.31.1
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.20.0
rev: v3.19.1
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.11
rev: v0.11.5
hooks:
- id: ruff
args: [--fix]
@@ -57,12 +57,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.26.0
rev: v8.24.3
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.8.0
rev: v1.5.2
hooks:
- id: zizmor

View File

@@ -10,8 +10,3 @@
- local: getting_started_real_world_robot
title: Getting Started with Real-World Robots
title: "Tutorials"
- sections:
- local: smolvla
title: Use SmolVLA
title: "Policies"

View File

@@ -1,91 +0,0 @@
# Use SmolVLA
SmolVLA is designed to be easy to use and integrate—whether you're finetuning on your own data or plugging it into an existing robotics stack.
<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/640e21ef3c82bd463ee5a76d/aooU0a3DMtYmy_1IWMaIM.png" alt="SmolVLA architecture." width="500"/>
<br/>
<em>Figure 2. SmolVLA takes as input a sequence of RGB images from multiple cameras, the robots current sensorimotor state, and a natural language instruction. The VLM encodes these into contextual features, which condition the action expert to generate a continuous sequence of actions.</em>
</p>
### Install
First, install the required dependencies:
```python
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e ".[smolvla]"
```
### Finetune the pretrained model
Use [`smolvla_base`](https://hf.co/lerobot/smolvla_base), our pretrained 450M model, with the lerobot training framework:
```python
python lerobot/scripts/train.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=lerobot/svla_so100_stacking \
--batch_size=64 \
--steps=200000
```
<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/640e21ef3c82bd463ee5a76d/S-3vvVCulChREwHDkquoc.gif" alt="Comparison of SmolVLA across task variations." width="500"/>
<br/>
<em>Figure 1: Comparison of SmolVLA across task variations. From left to right: (1) asynchronous pick-place cube counting, (2) synchronous pick-place cube counting, (3) pick-place cube counting under perturbations, and (4) generalization on pick-and-place of the lego block with real-world SO101.</em>
</p>
### Train from scratch
If you'd like to build from the architecture (pretrained VLM + action expert) rather than a pretrained checkpoint:
```python
python lerobot/scripts/train.py \
--policy.type=smolvla \
--dataset.repo_id=lerobot/svla_so100_stacking \
--batch_size=64 \
--steps=200000
```
You can also load `SmolVLAPolicy` directly:
```python
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
```
## Evaluate the pretrained policy and run it in real-time
If you want to record the evaluation process and safe the videos on the hub, login to your HF account by running:
```python
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```python
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Now, indicate the path to the policy, which is `lerobot/smolvla_base` in this case, and run:
```python
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/eval_svla_base_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=10 \
--control.push_to_hub=true \
--control.policy.path=lerobot/smolvla_base
```
Depending on your evaluation setup, you can configure the duration and the number of episodes to record for your evaluation suite.

View File

@@ -168,7 +168,12 @@ available_datasets = sorted(
)
# lists all available policies from `lerobot/common/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
available_policies = [
"act",
"diffusion",
"tdmpc",
"vqbet",
]
# lists all available robots from `lerobot/common/robot_devices/robots`
available_robots = [

View File

@@ -15,6 +15,5 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@@ -27,7 +27,6 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
@@ -60,10 +59,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "smolvla":
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -81,8 +76,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -1,154 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@PreTrainedConfig.register_subclass("smolvla")
@dataclass
class SmolVLAConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (512, 512)
# Add empty images. Used by smolvla_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = True
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
attention_mode: str = "cross_attn"
prefix_length: int = -1
pad_language_to: str = "longest" # "max_length"
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
num_vlm_layers: int = 16 # Number of layers used in the VLM (first num_vlm_layers layers)
self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list:
return [0]
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -1,801 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SmolVLA:
[Paper](https://huggingface.co/papers/2506.01844)
Designed by Hugging Face.
Install smolvla extra dependencies:
```bash
pip install -e ".[smolvla]"
```
Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
and an action expert.
```bash
python lerobot/scripts/train.py \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```
Example of using the smolvla pretrained model outside LeRobot training framework:
```python
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
```
"""
import math
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.common.policies.utils import (
populate_queues,
)
from lerobot.common.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def resize_with_pad(img, width, height, pad_value=-1):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with smolvla which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class SmolVLAPolicy(PreTrainedPolicy):
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
config_class = SmolVLAConfig
name = "smolvla"
def __init__(
self,
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For backward pass
loss_dict["loss"] = loss
return loss, loss_dict
def prepare_images(self, batch):
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
"""
images = []
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expacted by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
if f"{key}_padding_mask" in batch:
mask = batch[f"{key}_padding_mask"].bool()
else:
mask = torch.ones(bsize, dtype=torch.bool, device=device)
images.append(img)
img_masks.append(mask)
# Create image features not present in the batch
# as fully 0 padded images.
for num_empty_cameras in range(len(missing_img_keys)):
if num_empty_cameras >= self.config.empty_cameras:
break
img = torch.ones_like(img) * -1
mask = torch.zeros_like(mask)
images.append(img)
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_ROBOT].device
tasks = batch["task"]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
"""Pad state"""
state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
state = pad_vector(state, self.config.max_state_dim)
return state
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def pad_tensor(tensor, max_len, pad_value=0):
"""
Efficiently pads a tensor along sequence dimension to match max_len.
Args:
tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
max_len (int): Fixed sequence length.
pad_value (int/float): Value for padding.
Returns:
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
"""
b, d = tensor.shape[:2]
# Create a padded tensor of max_len and copy the existing values
padded_tensor = torch.full(
(b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, :d] = tensor # Efficient in-place copy
return padded_tensor
class VLAFlowMatching(nn.Module):
"""
SmolVLA
[Paper]()
Designed by Hugging Face.
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌─────────┐ ┌─|────┐ │
│ | │────► │ │ │
│ | │ kv │ │ │
│ | │────► │Action│ │
│ | VLM │cache │Expert│ |
│ │ │────► | │ │
│ │ │ │ │ │
│ └▲──▲───▲─┘ └───▲──┘ |
│ │ | | │ |
│ | | | noise │
│ │ │ state │
│ │ language tokens │
│ image(s) │
└──────────────────────────────┘
"""
def __init__(self, config):
super().__init__()
self.config = config
self.vlm_with_expert = SmolVLMWithExpertModel(
model_id=self.config.vlm_model_name,
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
load_vlm_weights=self.config.load_vlm_weights,
attention_mode=self.config.attention_mode,
num_expert_layers=self.config.num_expert_layers,
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
)
self.action_time_mlp_out = nn.Linear(
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
)
self.set_requires_grad()
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
self.global_image_start_token = torch.tensor(
[self.fake_image_token, self.global_image_token], dtype=torch.long
)
self.add_image_special_tokens = self.config.add_image_special_tokens
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for SmolVLM transformer processing.
"""
embs = []
pad_masks = []
att_masks = []
for _img_idx, (
img,
img_mask,
) in enumerate(zip(images, img_masks, strict=False)):
if self.add_image_special_tokens:
image_start_token = (
self.vlm_with_expert.embed_language_tokens(
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_start_mask = torch.ones_like(
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
)
att_masks += [0] * (image_start_mask.shape[-1])
embs.append(image_start_token)
pad_masks.append(image_start_mask)
img_emb = self.vlm_with_expert.embed_image(img)
img_emb = img_emb
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
pad_masks.append(img_mask)
att_masks += [0] * (num_img_embs)
if self.add_image_special_tokens:
image_end_token = (
self.vlm_with_expert.embed_language_tokens(
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
)
.unsqueeze(0)
.expand(img.shape[0], -1, -1)
)
image_end_mask = torch.ones_like(
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
)
embs.append(image_end_token)
pad_masks.append(image_end_mask)
att_masks += [0] * (image_end_mask.shape[1])
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
embs.append(lang_emb)
pad_masks.append(lang_masks)
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
state_emb = self.state_proj(state)
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
embs.append(state_emb)
bsize = state_emb.shape[0]
device = state_emb.device
states_seq_len = state_emb.shape[1]
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1] * (states_seq_len)
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = att_masks[None, :]
seq_len = pad_masks.shape[1]
if seq_len < self.prefix_length:
embs = pad_tensor(embs, self.prefix_length, pad_value=0)
pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
att_masks = att_masks.expand(bsize, -1)
return embs, pad_masks, att_masks
def embed_suffix(self, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
device = action_emb.device
bsize = action_emb.shape[0]
dtype = action_emb.dtype
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.vlm_with_expert.expert_hidden_size,
self.config.min_period,
self.config.max_period,
device=device,
)
time_emb = time_emb.type(dtype=dtype)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish == silu
action_time_emb = self.action_time_mlp_out(action_time_emb)
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] * self.config.chunk_size
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
(_, suffix_out), _ = self.vlm_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
# Original openpi code, upcast attention output
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.vlm_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step
x_t += dt * v_t
time += dt
return x_t
def denoise_step(
self,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
outputs_embeds, _ = self.vlm_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t

View File

@@ -1,550 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Optional
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim
class SmolVLMWithExpertModel(nn.Module):
def __init__(
self,
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
load_vlm_weights: bool = True,
train_expert_only: bool = True,
freeze_vision_encoder: bool = False,
attention_mode: str = "self_attn",
num_expert_layers: int = -1,
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
config = self.vlm.config
else:
config = AutoConfig.from_pretrained(model_id)
self.vlm = SmolVLMForConditionalGeneration(config=config)
self.processor = AutoProcessor.from_pretrained(model_id)
if num_vlm_layers > 0:
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
self.config = config
# Smaller lm expert
lm_expert_config = copy.deepcopy(config.text_config)
hidden_size = lm_expert_config.hidden_size
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
lm_expert_config.num_hidden_layers = self.num_vlm_layers
if num_expert_layers > 0:
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
)
lm_expert_config.num_hidden_layers = num_expert_layers
self.lm_expert = AutoModel.from_config(lm_expert_config)
self.num_expert_layers = len(self.lm_expert.layers)
self.self_attn_every_n_layers = self_attn_every_n_layers
if "cross" in attention_mode:
# Reshape qkv projections to have the same input dimension as the vlm
for layer_idx in range(len(self.lm_expert.layers)):
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
continue
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
# Remove unused embed_tokens
self.lm_expert.embed_tokens = None
self.num_attention_heads = self.config.text_config.num_attention_heads
self.num_key_value_heads = self.config.text_config.num_key_value_heads
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_mode = attention_mode
self.expert_hidden_size = lm_expert_config.hidden_size
self.set_requires_grad()
def get_vlm_model(self):
return self.vlm.model
def set_requires_grad(self):
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
for params in self.get_vlm_model().vision_model.parameters():
params.requires_grad = False
if self.train_expert_only:
self.vlm.eval()
for params in self.vlm.parameters():
params.requires_grad = False
else:
# To avoid unused params issue with distributed training
last_layers = [self.num_vlm_layers - 1]
if (
self.num_vlm_layers != self.num_expert_layers
and self.num_vlm_layers % self.num_expert_layers == 0
):
last_layers.append(self.num_vlm_layers - 2)
frozen_layers = [
"lm_head",
"text_model.model.norm.weight",
]
for layer in last_layers:
frozen_layers.append(f"text_model.model.layers.{layer}.")
for name, params in self.vlm.named_parameters():
if any(k in name for k in frozen_layers):
params.requires_grad = False
# To avoid unused params issue with distributed training
for name, params in self.lm_expert.named_parameters():
if "lm_head" in name:
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
if self.train_expert_only:
self.vlm.eval()
def embed_image(self, image: torch.Tensor):
patch_attention_mask = None
# Get sequence from the vision encoder
image_hidden_states = (
self.get_vlm_model()
.vision_model(
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
)
.last_hidden_state
)
# Modality projection & resampling
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor):
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
def forward_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
if hidden_states is None or layer is None:
continue
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
seq_len = query_states.shape[1]
if seq_len < position_ids.shape[1]:
_position_ids = position_ids[:, :seq_len]
_attention_mask = attention_mask[:, :seq_len, :seq_len]
else:
_position_ids = position_ids
_attention_mask = attention_mask
attention_mask_ = _attention_mask
position_ids_ = _position_ids
query_states = apply_rope(query_states, position_ids_)
key_states = apply_rope(key_states, position_ids_)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
)
return [att_output], past_key_values
def forward_cross_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
attention_interface = self.get_attention_interface()
att_outputs = []
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
)
if len(inputs_embeds) == 2 and not past_key_values:
# Prefix attention
seq_len = inputs_embeds[0].shape[1]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
layer = model_layers[0][layer_idx]
hidden_states = layer.input_layernorm(inputs_embeds[0])
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
# B,L,H,D with L sequence length, H number of heads, D head dim
query_states = apply_rope(query_state, position_id)
key_states = apply_rope(key_state, position_id)
att_output = attention_interface(
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_outputs.append(att_output)
else:
expert_position_id = position_ids
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = past_key_values[layer_idx]["key_states"]
value_states = past_key_values[layer_idx]["value_states"]
# Expert
expert_layer = model_layers[1][layer_idx]
if expert_layer is not None:
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
expert_input_shape = expert_hidden_states.shape[:-1]
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
*key_states.shape[:2], -1
)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
) # k_proj should have same dim as kv
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
*value_states.shape[:2], -1
)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
)
expert_position_id = (
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
) # start from 0
expert_attention_mask = attention_mask[
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
] # take into account kv
expert_query_states = apply_rope(expert_query_state, expert_position_id)
att_output = attention_interface(
expert_attention_mask,
batch_size,
head_dim,
expert_query_states,
expert_key_states,
expert_value_states,
)
att_outputs.append(att_output)
else:
att_outputs.append(None)
# att_output = att_output.to(dtype=models[i].dtype)
return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list:
vlm_layers = []
expert_layers = []
multiple_of = self.num_vlm_layers // self.num_expert_layers
for i in range(self.num_vlm_layers):
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
expert_layer = None
else:
expert_layer_index = i // multiple_of if multiple_of > 0 else i
expert_layer = models[1].layers[expert_layer_index]
vlm_layers.append(models[0].layers[i])
expert_layers.append(expert_layer)
return [vlm_layers, expert_layers]
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: List[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models)
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.num_vlm_layers
head_dim = self.vlm.config.text_config.head_dim
for layer_idx in range(num_layers):
if (
fill_kv_cache
or "cross" not in self.attention_mode
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
):
att_outputs, past_key_values = self.forward_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
else:
att_outputs, past_key_values = self.forward_cross_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
att_output = (
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
) # in case of self_attn
if hidden_states is not None:
if layer is None:
outputs_embeds.append(hidden_states)
continue
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
att_out = att_output[:, start:end]
out_emb = layer.self_attn.o_proj(att_out)
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end if len(att_outputs) == 1 else 0
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
attention_interface = self.eager_attention_forward
return attention_interface
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.num_attention_heads
num_key_value_heads = self.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
att_weights = att_weights.to(dtype=torch.float32)
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output

View File

@@ -109,10 +109,6 @@ def predict_action(observation, policy, device, use_amp):
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
# Skip all observations that are not tensors (e.g. text)
if not isinstance(observation[name], torch.Tensor):
continue
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
@@ -260,8 +256,7 @@ def control_loop(
else:
observation = robot.capture_observation()
action = None
observation["task"] = [single_task]
observation["robot_type"] = [policy.robot_type] if hasattr(policy, "robot_type") else [""]
if policy is not None:
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
@@ -272,7 +267,6 @@ def control_loop(
action = {"action": action}
if dataset is not None:
observation = {k: v for k, v in observation.items() if k not in ["task", "robot_type"]}
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)

View File

@@ -0,0 +1,60 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc StreamActions(Empty) returns (stream Action);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1;
bytes data = 2;
}
message Action {
// sent by remote Policy, to Robot
TransferState transfer_state = 1;
bytes data = 2;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}

View File

@@ -0,0 +1,48 @@
# fmt: off
# flake8: noqa
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x0bPolicySetup\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xa9\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=301
_globals['_TRANSFERSTATE']._serialized_end=397
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTION']._serialized_start=127
_globals['_ACTION']._serialized_end=205
_globals['_POLICYSETUP']._serialized_start=207
_globals['_POLICYSETUP']._serialized_end=290
_globals['_EMPTY']._serialized_start=292
_globals['_EMPTY']._serialized_end=299
_globals['_ASYNCINFERENCE']._serialized_start=400
_globals['_ASYNCINFERENCE']._serialized_end=697
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,236 @@
# fmt: off
# flake8: noqa
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.StreamActions = channel.unary_stream(
'/async_inference.AsyncInference/StreamActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Action.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/async_inference.AsyncInference/SendPolicyInstructions',
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StreamActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'StreamActions': grpc.unary_stream_rpc_method_handler(
servicer.StreamActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Action.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=async__inference__pb2.PolicySetup.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def StreamActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/async_inference.AsyncInference/StreamActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Action.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/SendPolicyInstructions',
async__inference__pb2.PolicySetup.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -0,0 +1,12 @@
"""Server/Client side: Sometimes you just want the environment to wait a tiny bit"""
idle_wait = 0.01
"""Client side: The environment evolves with a time resolution equal to environment_dt"""
environment_dt = 1 / 30
"""Server side: Running inference on (at most) environment_dt"""
inference_latency = environment_dt
"""Supported policies"""
supported_policies = ["act", "smolvla"]

View File

@@ -0,0 +1,128 @@
import logging
import logging.handlers
import os
import time
from typing import Any
import torch
def setup_logging(prefix: str, info_bracket: str):
"""Sets up logging"""
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)
# Delete any existing prefix_* log files
for old_log_file in os.listdir("logs"):
if old_log_file.startswith(prefix) and old_log_file.endswith(".log"):
try:
os.remove(os.path.join("logs", old_log_file))
print(f"Deleted old log file: {old_log_file}")
except Exception as e:
print(f"Failed to delete old log file {old_log_file}: {e}")
# Set up logging with both console and file output
logger = logging.getLogger(prefix)
# Prevent propagation to root logger to avoid duplicate messages
logger.propagate = False
logger.setLevel(logging.INFO)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(
logging.Formatter(
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
logger.addHandler(console_handler)
# File handler - creates a new log file for each run
file_handler = logging.handlers.RotatingFileHandler(
f"logs/policy_server_{int(time.time())}.log",
maxBytes=10 * 1024 * 1024, # 10MB
backupCount=5,
)
file_handler.setFormatter(
logging.Formatter(
f"%(asctime)s.%(msecs)03d [{info_bracket}] [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
logger.addHandler(file_handler)
return logger
class TimedData:
def __init__(self, timestamp: float, data: Any, timestep: int):
"""Initialize a TimedData object.
Args:
timestamp: Unix timestamp relative to data's creation.
data: The actual data to wrap a timestamp around.
timestep: The timestep of the data.
"""
self.timestamp = timestamp
self.data = data
self.timestep = timestep
def get_data(self):
return self.data
def get_timestamp(self):
return self.timestamp
def get_timestep(self):
return self.timestep
class TimedAction(TimedData):
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
def get_action(self):
return self.get_data()
class TimedObservation(TimedData):
def __init__(
self,
timestamp: float,
observation: dict[str, torch.Tensor],
timestep: int,
transfer_state: int = 0,
must_go: bool = False,
):
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
self.transfer_state = transfer_state
self.must_go = must_go
def get_observation(self):
return self.get_data()
class TinyPolicyConfig:
def __init__(
self,
policy_type: str = "act",
pretrained_name_or_path: str = "fracapuano/act_so100_test",
device: str = "cpu",
):
self.policy_type = policy_type
self.pretrained_name_or_path = pretrained_name_or_path
self.device = device
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
"""Check if two observation states are similar, under a tolerance threshold"""
return torch.linalg.norm(obs1_state - obs2_state) < atol
def observations_similar(obs1: TimedObservation, obs2: TimedObservation, atol: float = 1) -> bool:
"""Check if two observations are similar, under a tolerance threshold"""
obs1_state = obs1.get_observation()["observation.state"]
obs2_state = obs2.get_observation()["observation.state"]
return _compare_observation_states(obs1_state, obs2_state, atol=atol)

View File

@@ -0,0 +1,429 @@
import itertools
import pickle # nosec
import time
from concurrent import futures
from queue import Queue
from typing import Generator, List, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from datasets import load_dataset
from lerobot.common.policies.factory import get_policy_class
from lerobot.scripts.server.constants import environment_dt, idle_wait, inference_latency, supported_policies
from lerobot.scripts.server.helpers import (
TimedAction,
TimedObservation,
TinyPolicyConfig,
observations_similar,
setup_logging,
)
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server"
info_bracket = "SERVER"
logger = setup_logging(prefix, info_bracket)
def __init__(self):
# Initialize dataset action generator (to debug this first version, will be removed in the future)
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
self._setup_server()
self.actions_per_chunk = 20
self.actions_overlap = 10
self.running = True
def _setup_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self.observation_queue = Queue(maxsize=1)
self._predicted_timesteps = set()
self._predicted_observations = Queue(maxsize=1)
def Ready(self, request, context): # noqa: N802
client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready")
self._setup_server()
return async_inference_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client"""
client_id = context.peer()
self.logger.debug(f"Receiving policy instructions from {client_id}")
policy_specs = pickle.loads(request.data) # nosec
assert isinstance(policy_specs, TinyPolicyConfig), (
f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}"
)
self.logger.info(
f"Policy type: {policy_specs.policy_type} | "
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
f"Device: {policy_specs.device}"
)
assert policy_specs.policy_type in supported_policies, (
f"Policy type {policy_specs.policy_type} not supported. Supported policies: {supported_policies}"
)
self.device = policy_specs.device
self.policy_type = policy_specs.policy_type # act, pi0, etc.
policy_class = get_policy_class(self.policy_type)
start = time.time()
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
self.policy.to(self.device)
end = time.time()
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
client_id = context.peer()
self.logger.debug(f"Receiving observations from {client_id}")
for observation in request_iterator:
receive_time = time.time()
timed_observation = pickle.loads(observation.data) # nosec
deserialize_time = time.time()
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
if not self._maybe_enqueue_observation(timed_observation):
continue
queue_time = time.time()
obs_timestep = timed_observation.get_timestep()
obs_timestamp = timed_observation.get_timestamp()
self.logger.info(
f"Received observation #{obs_timestep} | "
f"Client timestamp: {obs_timestamp:.6f} | "
f"Server timestamp: {receive_time:.6f} | "
)
if not hasattr(self, "previous_obs_timestamp"):
self.previous_obs_timestamp = obs_timestamp
self.logger.debug(
f"1/DeltaObsT (~frequency): {1 / (1e-6 + obs_timestamp - self.previous_obs_timestamp):.6f} Hz| "
f"Network latency: {receive_time - obs_timestamp:.6f}s | "
f"Deserialization time: {deserialize_time - receive_time:.6f}s | "
f"Queue time: {queue_time - deserialize_time:.6f}s | "
)
self.previous_obs_timestamp = obs_timestamp
return async_inference_pb2.Empty()
def StreamActions(self, request, context): # noqa: N802
"""Stream actions to the robot client"""
client_id = context.peer()
self.logger.debug(f"Client {client_id} connected for action streaming")
# Generate action based on the most recent observation and its timestep
try:
obs = self.observation_queue.get()
self.logger.info(
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
)
if obs:
self.last_predicted_obs = obs
self._predicted_timesteps.add(obs.get_timestep())
start_time = time.time()
action_chunk = self._predict_action_chunk(obs)
# action_chunk = self._read_action_chunk(obs)
inference_time = time.time() - start_time
start_time = time.time()
action_bytes = pickle.dumps(action_chunk) # nosec
serialize_time = time.time() - start_time
# Create and return the Action
action = async_inference_pb2.Action(transfer_state=obs.transfer_state, data=action_bytes)
self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | Inference time: {inference_time:.6f}s |"
)
self.logger.debug(
f"Action chunk #{obs.get_timestep()} generated | "
f"Inference time: {inference_time:.6f}s |"
f"Serialize time: {serialize_time:.6f}s |"
f"Total time: {inference_time + serialize_time:.6f}s"
)
yield action
else:
self.logger.warning("No observation in queue yet!")
time.sleep(idle_wait)
except Exception as e:
self.logger.error(f"Error in StreamActions: {e}")
return async_inference_pb2.Empty()
def _enqueue_and_go(self, obs: TimedObservation):
# If queue is full, get the old observation to make room
if self.observation_queue.full():
# pops from queue
_ = self.observation_queue.get_nowait()
self.logger.debug("Observation queue was full, removed oldest observation")
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(obs)
return True
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
if obs.get_timestep() in self._predicted_timesteps:
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
return False
elif observations_similar(obs, previous_obs, atol=1):
self.logger.debug(
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
)
return False
else:
return True
def _maybe_enqueue_observation(self, obs: TimedObservation) -> bool:
"""Enqueue an observation if it must go through processing, otherwise skip it.
Observations not in queue are never run through the policy network"""
if obs.must_go or not hasattr(self, "last_predicted_obs"):
self.logger.info(f"[MUST GO] Enqueued observation #{obs.get_timestep()} for direct processing!")
return self._enqueue_and_go(obs)
else:
if self._obs_sanity_checks(obs, self.last_predicted_obs):
return self._enqueue_and_go(obs)
else:
return False
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
]
@torch.no_grad()
def _run_act_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Run ACT-like policies"""
start_time = time.time()
# prepare observation for policy forward pass
batch = self.policy.normalize_inputs(observation)
normalize_time = time.time()
self.logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s")
if self.policy.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.policy.config.image_features]
prep_time = time.time()
self.logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s")
# forward pass outputs up to policy.config.n_action_steps != actions_per_chunk
actions = self.policy.model(batch)[0][:, : self.actions_per_chunk]
actions = self.policy.unnormalize_outputs({"action": actions})["action"]
end_time = time.time()
self.logger.info(f"[ACT] Action chunk generation total time: {end_time - start_time:.6f}s")
return actions
@torch.no_grad()
def _run_pi0_policy(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Run PI0-like policies"""
raise NotImplementedError("PI0 policy not implemented yet")
@torch.no_grad()
def _run_smolvla_policy(
self, observation: dict[str, torch.Tensor], noise: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Run smolvla-like policies"""
observation = self.policy.normalize_inputs(observation)
images, img_masks = self.policy.prepare_images(observation)
state = self.policy.prepare_state(observation)
lang_tokens, lang_masks = self.policy.prepare_language(observation)
actions = self.policy.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.policy.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.policy.unnormalize_outputs(
{"action": actions, "robot_type": [self.policy.config.robot_type]}
)["action"]
return actions
def _get_action_chunk(
self, observation: dict[str, torch.Tensor], policy_type: str = "act"
) -> torch.Tensor:
"""Get an action chunk from the policy"""
if policy_type == "act":
return self._run_act_policy(observation)
elif policy_type == "smolvla":
return self._run_smolvla_policy(observation)
else:
raise ValueError(f"Policy class {policy_type} not supported")
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action based on the observation"""
"""1. Prepare observation"""
start_time = time.time()
observation = {
"robot_type": [self.policy.config.robot_type],
}
for k, v in observation_t.get_observation().items():
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions
if "image" in k:
# Add batch dimension first, then reorder to NCHW format, then normalize to [0, 1]
observation[k] = (
v.unsqueeze(0).permute(0, 3, 1, 2).to(self.device, non_blocking=True) / 255.0
)
else:
observation[k] = v.unsqueeze(0).to(self.device, non_blocking=True)
else:
observation[k] = v # textual instructions are passed as a list of strings
prep_time = time.time()
self.logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s")
"""2. Get action chunk"""
action_tensor = self._get_action_chunk(observation, self.policy_type)
action_tensor = action_tensor.squeeze(0)
# Move to CPU before serializing
action_tensor = action_tensor.cpu()
post_inference_time = time.time()
self.logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s")
if action_tensor.dim() == 1:
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
action_tensor = action_tensor.repeat(self.actions_per_chunk, 1)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
chunk_time = time.time()
self.logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s")
time.sleep(
max(0, inference_latency - max(0, chunk_time - start_time))
) # sleep to control inference latency
return action_chunk
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
"""Stream chunks of actions from a prerecorded dataset.
Returns:
Generator that yields chunks of actions from the dataset
"""
import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")
# 1. Select the action column only, where you will find tensors with 6 elements
actions = dataset["action"]
action_indices = torch.arange(len(actions))
# 2. Chunk the iterable of tensors into chunks with 10 elements each
# sending only first element for debugging
indices_chunks = action_indices.unfold(
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
)
for idx_chunk in indices_chunks:
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
def _read_action_chunk(self, observation: Optional[TimedObservation] = None) -> list[TimedAction]:
"""Dummy function for predicting action chunk given observation.
Instead of computing actions on-the-fly, this method streams
actions from a prerecorded dataset.
"""
import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
start_time = time.time()
if not observation:
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
# Get chunk of actions from the generator
actions_chunk = next(self.action_generator)
# Return a list of TimedActions, with timestamps starting from the observation timestamp
actions_chunk = self._time_action_chunk(
observation.get_timestamp(), actions_chunk, observation.get_timestep()
)
chunk_time = time.time()
self.logger.debug(f"Action chunk creation time: {chunk_time - start_time:.6f}s")
# slow action generation, emulates inference time
time.sleep(max(0, inference_latency - max(0, chunk_time - start_time)))
return actions_chunk
def stop(self):
"""Stop the server"""
self.running = False
self.logger.info("Server stopping...")
def serve():
port = 8080
# Create the server instance first
policy_server = PolicyServer()
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
server.add_insecure_port(f"[::]:{port}")
server.start()
policy_server.logger.info(f"PolicyServer started on port {port}")
try:
# Use the running attribute to control server lifetime
while policy_server.running:
time.sleep(1) # Check every second instead of sleeping indefinitely
except KeyboardInterrupt:
policy_server.stop()
policy_server.logger.info("Keyboard interrupt received")
if __name__ == "__main__":
serve()

View File

@@ -0,0 +1,608 @@
import argparse
import os
import pickle # nosec
import threading
import time
from queue import Empty, Queue
from typing import Callable, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.scripts.server.constants import environment_dt, idle_wait
from lerobot.scripts.server.helpers import TimedAction, TimedObservation, TinyPolicyConfig, setup_logging
class RobotClient:
prefix = "robot_client"
info_bracket = "CLIENT"
logger = setup_logging(prefix, info_bracket)
def __init__(
self,
server_address: Optional[str] = None,
policy_type: str = "smolvla",
pretrained_name_or_path: str = "lerobot/smolvla_base",
policy_device: str = "cuda",
chunk_size_threshold: float = 0.5,
robot: str = "so100",
):
# Use environment variable if server_address is not provided
if server_address is None:
server_address = os.getenv("SERVER_ADDRESS", "localhost:8080")
self.logger.info(f"No server address provided, using default address: {server_address}")
self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device)
self.channel = grpc.insecure_channel(server_address)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {server_address}")
self.running = False
self.must_go = True # does the observation qualify for direct processing on the policy server?
self.latest_action = -1
self.action_chunk_size = -1
self._chunk_size_threshold = chunk_size_threshold
self.action_queue = Queue()
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
start_time = time.time()
self.robot = make_robot(robot)
self.robot.connect()
connect_time = time.time()
self.logger.info(f"Robot connection time: {connect_time - start_time:.4f}s")
time.sleep(idle_wait) # sleep waiting for cameras to activate
self.logger.info("Robot connected and ready")
def timestamps(self):
"""Get the timestamps of the actions in the queue"""
return sorted([action.get_timestep() for action in self.action_queue.queue])
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
start_time = time.time()
self.stub.Ready(async_inference_pb2.Empty())
end_time = time.time()
self.logger.info(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(
transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes
)
self.logger.info("Sending policy instructions to policy server")
self.logger.info(
f"Policy type: {self.policy_config.policy_type} | "
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
f"Device: {self.policy_config.device}"
)
self.stub.SendPolicyInstructions(policy_setup)
self.running = True
self.available_actions_size = []
return True
except grpc.RpcError as e:
self.logger.error(f"Failed to connect to policy server: {e}")
return False
def stop(self):
"""Stop the robot client"""
self.running = False
self.robot.disconnect()
self.logger.info("Robot disconnected")
self.channel.close()
self.logger.info("Client stopped, channel closed")
def send_observation(
self,
obs: TimedObservation,
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool:
"""Send observation to the policy server.
Returns True if the observation was sent successfully, False otherwise."""
if not self.running:
self.logger.warning("Client not running")
return False
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
start_time = time.time()
observation_bytes = pickle.dumps(obs)
serialize_time = time.time()
self.logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s")
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
try:
send_start = time.time()
_ = self.stub.SendObservations(iter([observation]))
send_end = time.time()
obs_timestep = obs.get_timestep()
self.logger.info(
f"Sent observation #{obs_timestep} | "
f"Serialize time: {serialize_time - start_time:.6f}s | "
f"Network time: {send_end - send_start:.6f}s | "
f"Total time: {send_end - start_time:.6f}s"
)
self.last_obs_sent_time = send_end
return True
except grpc.RpcError as e:
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
return False
def _validate_action(self, action: TimedAction):
"""Received actions are keps only when they have been produced for now or later, never before"""
return not action.get_timestep() <= self.latest_action
def _inspect_action_queue(self):
queue_size = self.action_queue.qsize()
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
return queue_size, timestamps
def _update_action_queue(self, actions: list[TimedAction]):
"""Update the action queue with new actions, without ever emptying the queue"""
new_queue = Queue()
for action in actions:
if self._validate_action(action):
new_queue.put(action)
self.action_queue = new_queue
def _aggregate_action_queues(
self,
incoming_actions: list[TimedAction],
aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
# TODO(fracapuano): move outside of the function and make aggregate_fn an always required argument
if not aggregate_fn:
# default aggregate function: take the latest action
def aggregate_fn(x1, x2):
return x2
action_intersections: list[torch.Tensor] = []
current_action_queue = {
action.get_timestep(): action.get_action() for action in self.action_queue.queue
}
for new_action in incoming_actions:
if new_action.get_timestep() in current_action_queue:
# TODO(fracapuano): There is probably a way to do this with broadcasting of the two action tensors
action_intersections.append(
TimedAction(
timestamp=new_action.get_timestamp(),
action=aggregate_fn(
current_action_queue[new_action.get_timestep()], new_action.get_action()
),
timestep=new_action.get_timestep(),
)
)
else:
action_intersections.append(new_action)
new_queue = Queue()
for action in action_intersections:
if self._validate_action(action):
new_queue.put(action)
self.action_queue = new_queue
def _clear_action_queue(self):
"""Clear the existing queue"""
while not self.action_queue.empty():
try:
self.action_queue.get_nowait()
except Empty:
break
def _fill_action_queue(self, actions: list[TimedAction]):
"""Fill the action queue with incoming valid actions"""
start_time = time.time()
valid_count = 0
for action in actions:
if self._validate_action(action):
self.action_queue.put(action)
valid_count += 1
end_time = time.time()
self.logger.debug(
f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s"
)
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
self._clear_action_queue()
self._fill_action_queue(actions)
def receive_actions(self):
"""Receive actions from the policy server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Action receiving thread starting")
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
receive_time = time.time()
# Deserialize bytes back into list[TimedAction]
deserialize_start = time.time()
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_end = time.time()
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
start_time = time.time()
self.logger.info(f"Current latest action: {self.latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [self.latest_action] # queue was empty
# Log incoming actions
incoming_timesteps = [a.get_timestep() for a in timed_actions]
# Calculate network latency if we have matching observations
if len(timed_actions) > 0:
first_action_timestep = timed_actions[0].get_timestep()
server_to_client_latency = receive_time - self.last_obs_sent_time
self.logger.info(
f"Received action chunk for step #{first_action_timestep} | "
f"Latest action: #{self.latest_action} | "
f"Network latency (server->client): {server_to_client_latency:.6f}s | "
f"Deserialization time: {deserialize_end - deserialize_start:.6f}s"
)
# Update action queue
start_time = time.time()
self._update_action_queue(timed_actions)
queue_update_time = time.time() - start_time
self.must_go = (
True # after receiving actions, next empty queue triggers must-go processing!
)
# Get queue state after changes
new_size, new_timesteps = self._inspect_action_queue()
self.logger.info(
f"Queue update complete ({queue_update_time:.6f}s) | "
f"Before: {old_size} items | "
f"After: {new_size} items | "
)
self.logger.info(
f"Latest action: {self.latest_action} | "
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
)
except grpc.RpcError as e:
self.logger.error(f"Error receiving actions: {e}")
# Avoid tight loop on action receiver error
time.sleep(idle_wait)
def _actions_available(self):
"""Check if there are actions available in the queue"""
return not self.action_queue.empty()
def _get_next_action(self) -> Optional[TimedAction]:
"""Get the next action from the queue"""
try:
action = self.action_queue.get_nowait()
return action
except Empty:
return None
def _perform_action(self, timed_action: TimedAction):
self.robot.send_action(timed_action.get_action())
self.latest_action = timed_action.get_timestep()
self.logger.debug(
f"Ts={timed_action.get_timestamp()} | "
f"Action #{timed_action.get_timestep()} performed | "
f"Queue size: {self.action_queue.qsize()}"
)
def execute_actions(self):
"""Continuously execute actions from the queue"""
import warnings
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
# Wait at barrier for synchronized start
self.start_barrier.wait()
time.sleep(idle_wait) # wait for observation capture to start
self.logger.info("Action execution thread starting")
while self.running:
# constantly monitor the size of the action queue
self.available_actions_size.append(self.action_queue.qsize())
if self._actions_available():
timed_action = self._get_next_action()
self._perform_action(timed_action)
time.sleep(environment_dt)
else:
self.logger.debug("No action available | Sleeping")
time.sleep(idle_wait)
def stream_observations(self, get_observation_fn):
"""Continuously stream observations to the server"""
import warnings
warnings.warn("This method is deprecated! Will be removed soon!", stacklevel=2)
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Observation streaming thread starting")
while self.running:
try:
# Get serialized observation bytes from the function
start_time = time.time()
observation = get_observation_fn()
obs_capture_time = time.time() - start_time
self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
if not hasattr(self, "last_obs_timestamp"):
self.last_obs_timestamp = observation.get_timestamp()
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
self.logger.info(
f"Ts={obs_timestamp} | "
f"Captured observation #{obs_timestep} | "
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
)
self.last_obs_timestamp = obs_timestamp
# Set appropriate transfer state
if obs_timestep == 0:
state = async_inference_pb2.TRANSFER_BEGIN
else:
state = async_inference_pb2.TRANSFER_MIDDLE
time.sleep(environment_dt)
self.send_observation(observation, state)
except Exception as e:
self.logger.error(f"Error in observation sender: {e}")
time.sleep(idle_wait)
def control_loop_action(self):
"""Reading and performing actions in local queue"""
self.available_actions_size.append(self.action_queue.qsize())
if self._actions_available():
# Get action from queue
get_start = time.time()
timed_action = self._get_next_action()
get_end = time.time() - get_start
self.logger.debug(
f"Popping action from queue to perform took {get_end:.6f}s | "
f"Queue size: {self.action_queue.qsize()}"
)
self._perform_action(timed_action)
def _ready_to_send_observation(self):
"""Flags when the client is ready to send an observation"""
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
def control_loop_observation(self, get_observation_fn):
try:
# Get serialized observation bytes from the function
start_time = time.time()
observation = get_observation_fn()
obs_capture_time = time.time() - start_time
# If there are no actions left in the queue, the observation must go through processing!
observation.must_go = self.must_go and self.action_queue.empty()
self.logger.debug(f"QUEUE SIZE: {self.action_queue.qsize()} (Must go: {observation.must_go})")
if observation.must_go:
# must-go flag will be set again after receiving actions
self.must_go = False
if not hasattr(self, "last_obs_timestamp"):
self.last_obs_timestamp = observation.get_timestamp()
obs_timestep, obs_timestamp = observation.get_timestep(), observation.get_timestamp()
self.last_obs_timestamp = obs_timestamp
self.logger.info(
f"Ts={obs_timestamp} | "
f"Captured observation #{obs_timestep} | "
f"1/DeltaTs (~frequency)={1 / (1e-6 + obs_timestamp - self.last_obs_timestamp):.6f}"
)
self.logger.debug(f"Capturing observation took {obs_capture_time:.6f}s")
# Set appropriate transfer state
if obs_timestep == 0:
state = async_inference_pb2.TRANSFER_BEGIN
else:
state = async_inference_pb2.TRANSFER_MIDDLE
self.send_observation(observation, state)
except Exception as e:
self.logger.error(f"Error in observation sender: {e}")
def control_loop(self, get_observation_fn):
"""Combined function for executing actions and streaming observations"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Control loop thread starting")
control_loops = 0
while self.running:
control_loop_start = time.time()
self.control_loop_action()
"""Control loop: (2) Streaming observations to the remote policy server"""
if self._ready_to_send_observation() or control_loops == 0:
self.control_loop_observation(get_observation_fn)
# Dynamically adjust sleep time to maintain the desired control frequency
time.sleep(max(0, environment_dt - (time.time() - control_loop_start)))
control_loops += 1
def async_client(task_instruction: str, verbose: int = 0):
client = RobotClient()
if client.start():
# Function to get observations from the robot
def get_observation():
observation_content = None
observation_content = client.robot.capture_observation()
observation_content["task"] = [task_instruction]
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
)
return observation
client.logger.info("Starting all threads...")
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions)
action_receiver_thread.daemon = True
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
control_loop_thread.daemon = True
# Start all threads
action_receiver_thread.start()
control_loop_thread.start()
try:
while client.running:
time.sleep(idle_wait)
except KeyboardInterrupt:
pass
finally:
client.stop()
client.logger.info("Client stopped")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Robot client for executing tasks via policy server")
parser.add_argument(
"--task",
type=str,
required=True,
help="Task instruction for the robot to execute (e.g., 'fold my tshirt')",
)
parser.add_argument("--verbose", type=int, default=0, help="Verbosity level (default: 0)")
parser.add_argument(
"--server-port-address",
type=str,
default="localhost:8080",
help="Server & port address (default: localhost:8080, or SERVER_ADDRESS env var)",
)
parser.add_argument("--policy-type", type=str, default="smolvla", help="Policy type (default: smolvla)")
parser.add_argument(
"--pretrained-name-or-path",
type=str,
default="lerobot/smolvla_base",
help="Pretrained model name or path (default: lerobot/smolvla_base)",
)
parser.add_argument(
"--policy-device", type=str, default="cuda", help="Device for policy inference (default: cuda)"
)
parser.add_argument(
"--chunk-size-threshold",
type=float,
default=0.5,
help="Chunk size threshold (`g` in the paper, default: 0.5)",
)
parser.add_argument(
"--robot",
type=str,
default="so100",
help="Robot name, as per the `make_robot` function (default: so100)",
)
args = parser.parse_args()
# Create client with parsed arguments
client = RobotClient(
server_address=args.server_address,
policy_type=args.policy_type,
pretrained_name_or_path=args.pretrained_name_or_path,
policy_device=args.policy_device,
chunk_size_threshold=args.chunk_size_threshold,
robot=args.robot,
)
if client.start():
# Function to get observations from the robot
def get_observation():
observation_content = None
observation_content = client.robot.capture_observation()
observation_content["task"] = [args.task]
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=max(client.latest_action, 0)
)
return observation
client.logger.info("Starting all threads...")
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions)
action_receiver_thread.daemon = True
control_loop_thread = threading.Thread(target=client.control_loop, args=(get_observation,))
control_loop_thread.daemon = True
# Start all threads
action_receiver_thread.start()
control_loop_thread.start()
try:
while client.running:
time.sleep(idle_wait)
except KeyboardInterrupt:
pass
finally:
client.stop()
client.logger.info("Client stopped")

View File

@@ -63,7 +63,7 @@ dependencies = [
"opencv-python-headless>=4.9.0",
"packaging>=24.2",
"av>=14.2.0",
"pymunk>=6.6.0,<7.0.0",
"pymunk>=6.6.0",
"pynput>=1.7.7",
"pyzmq>=26.2.1",
"rerun-sdk>=0.21.0",
@@ -86,7 +86,6 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"]
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",

View File

@@ -45,7 +45,12 @@ def test_available_policies():
This test verifies that the class attribute `name` for all policies is
consistent with those listed in `lerobot/__init__.py`.
"""
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
policy_classes = [
ACTPolicy,
DiffusionPolicy,
TDMPCPolicy,
VQBeTPolicy,
]
policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies