multi-node openpi commit

This commit is contained in:
Leon998
2026-03-17 23:05:23 +08:00
parent 28833f0c0f
commit 7411e0e004
156 changed files with 33951 additions and 1 deletion


@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.
# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
SHELL ["/bin/bash", "-c"]
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
cmake \
curl \
libffi-dev \
python3-rosdep \
python3-rosinstall \
python3-rosinstall-generator \
whiptail \
git \
wget \
openssh-client \
ros-noetic-cv-bridge \
ros-noetic-usb-cam \
ros-noetic-realsense2-camera \
keyboard-configuration
WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
cd /python && \
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
tar -zxvf Python-3.10.14.tgz && \
cd Python-3.10.14 && \
ls -lhR && \
./configure --enable-optimizations && \
make install && \
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
cd ~ && rm -rf /python && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app
# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]


@@ -0,0 +1,126 @@
# Run Aloha (Real Robot)
This example demonstrates how to run a policy on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
## Prerequisites
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
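The exact layout of `realsense_publisher.py` depends on the fork, but the edit usually amounts to mapping each camera name to the serial number printed on your Realsense units. A hypothetical sketch (camera names follow the standard ALOHA setup; the serial numbers are placeholders you must replace with your own):

```python
# Hypothetical example: map each ALOHA camera to the serial number of the
# Realsense device that feeds it. Replace the placeholder serials with the
# values reported by `rs-enumerate-devices` for your cameras.
SERIAL_NUMBERS = {
    "cam_high": "130322270656",
    "cam_low": "130322271087",
    "cam_left_wrist": "130322270087",
    "cam_right_wrist": "130322271628",
}
```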
## With Docker
```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client
# Run the robot
python -m examples.aloha_real.main
```
Terminal window 2:
```bash
roslaunch aloha ros_nodes.launch
```
Terminal window 3:
```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```
## **ALOHA Checkpoint Guide**
The `pi0_base` model can be used zero-shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, "fold the towel" and "open the tupperware and put the food on the plate," which can perform more advanced tasks on the ALOHA.
While we've found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we've seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.
---
### **Toast Task**
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
- Works on both real toast and rubber fake toast
- Compatible with standard 2-slice toasters
- Works with plates of varying colors
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower-center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
### **Towel Task**
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
- Works on towels of varying solid colors
- Performance is worse on heavily textured or striped towels
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.
### **Tupperware Task**
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
- The policy has seen plates of varying solid colors.
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
- Tupperware should be on the left.
- Plate should be on the right or bottom.
- The tupperware flap should point toward the plate.
## Training on your own Aloha dataset
1. Convert the dataset to the LeRobot dataset v2.0 format.
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example, we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
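Following the usage line in the script's docstring, a typical invocation looks like this (the raw data path and repo ID below are placeholders for your own dataset):

```bash
# Convert raw ALOHA hdf5 episodes to LeRobot v2.0 and push to the Hub.
# Replace --raw-dir and --repo-id with your own dataset location and name.
uv run examples/aloha_real/convert_aloha_data_to_lerobot.py \
  --raw-dir /path/to/aloha_pen_uncap_diverse_raw \
  --repo-id <org>/<dataset-name>
```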
2. Define a training config that uses the custom dataset.
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the `trossen` asset ID and a path to the pretrained checkpoint's asset directory within the `AssetsConfig`.
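As a rough sketch of what that config fragment looks like (field names follow our reading of the `pi0_aloha_pen_uncap` example in `src/openpi/training/config.py`; treat the checkpoint path as a placeholder and check the config file for the authoritative form):

```python
# Hypothetical sketch: point the fine-tuning config at the normalization
# stats shipped with the base checkpoint for the trossen (ALOHA) config.
AssetsConfig(
    assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
    asset_id="trossen",
)
```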


@@ -0,0 +1,66 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
runtime:
image: aloha_real
depends_on:
- aloha_ros_nodes
- ros_master
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
aloha_ros_nodes:
image: aloha_real
depends_on:
- ros_master
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- /dev:/dev
command: roslaunch --wait aloha ros_nodes.launch
ros_master:
image: ros:noetic-robot
network_mode: host
privileged: true
command:
- roscore
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]


@@ -0,0 +1,71 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
### Task parameters
### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
############################ Helper functions ############################
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
MASTER_POS2JOINT = (
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2


@@ -0,0 +1,272 @@
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""
import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
use_videos: bool = True
tolerance_s: float = 0.0001
image_writer_processes: int = 10
image_writer_threads: int = 5
video_backend: str | None = None
DEFAULT_DATASET_CONFIG = DatasetConfig()
def create_empty_dataset(
repo_id: str,
robot_type: str,
mode: Literal["video", "image"] = "video",
*,
has_velocity: bool = False,
has_effort: bool = False,
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
motors = [
"right_waist",
"right_shoulder",
"right_elbow",
"right_forearm_roll",
"right_wrist_angle",
"right_wrist_rotate",
"right_gripper",
"left_waist",
"left_shoulder",
"left_elbow",
"left_forearm_roll",
"left_wrist_angle",
"left_wrist_rotate",
"left_gripper",
]
cameras = [
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
]
features = {
"observation.state": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
"action": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
}
if has_velocity:
features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
if has_effort:
features["observation.effort"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
for cam in cameras:
features[f"observation.images.{cam}"] = {
"dtype": mode,
"shape": (3, 480, 640),
"names": [
"channels",
"height",
"width",
],
}
if Path(LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
return LeRobotDataset.create(
repo_id=repo_id,
fps=50,
robot_type=robot_type,
features=features,
use_videos=dataset_config.use_videos,
tolerance_s=dataset_config.tolerance_s,
image_writer_processes=dataset_config.image_writer_processes,
image_writer_threads=dataset_config.image_writer_threads,
video_backend=dataset_config.video_backend,
)
def get_cameras(hdf5_files: list[Path]) -> list[str]:
with h5py.File(hdf5_files[0], "r") as ep:
# ignore depth channel, not currently handled
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
def has_velocity(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/qvel" in ep
def has_effort(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/effort" in ep
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
imgs_per_cam = {}
for camera in cameras:
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
if uncompressed:
# load all images in RAM
imgs_array = ep[f"/observations/images/{camera}"][:]
else:
import cv2
# load one compressed image after the other in RAM and uncompress
imgs_array = []
for data in ep[f"/observations/images/{camera}"]:
imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
imgs_array = np.array(imgs_array)
imgs_per_cam[camera] = imgs_array
return imgs_per_cam
def load_raw_episode_data(
ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
with h5py.File(ep_path, "r") as ep:
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
velocity = None
if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:])
effort = None
if "/observations/effort" in ep:
effort = torch.from_numpy(ep["/observations/effort"][:])
imgs_per_cam = load_raw_images_per_camera(
ep,
[
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
],
)
return imgs_per_cam, state, action, velocity, effort
def populate_dataset(
dataset: LeRobotDataset,
hdf5_files: list[Path],
task: str,
episodes: list[int] | None = None,
) -> LeRobotDataset:
if episodes is None:
episodes = range(len(hdf5_files))
for ep_idx in tqdm.tqdm(episodes):
ep_path = hdf5_files[ep_idx]
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
num_frames = state.shape[0]
for i in range(num_frames):
frame = {
"observation.state": state[i],
"action": action[i],
}
for camera, img_array in imgs_per_cam.items():
frame[f"observation.images.{camera}"] = img_array[i]
if velocity is not None:
frame["observation.velocity"] = velocity[i]
if effort is not None:
frame["observation.effort"] = effort[i]
dataset.add_frame(frame)
dataset.save_episode(task=task)
return dataset
def port_aloha(
raw_dir: Path,
repo_id: str,
raw_repo_id: str | None = None,
task: str = "DEBUG",
*,
episodes: list[int] | None = None,
push_to_hub: bool = True,
is_mobile: bool = False,
mode: Literal["video", "image"] = "image",
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if not raw_dir.exists():
if raw_repo_id is None:
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
download_raw(raw_dir, repo_id=raw_repo_id)
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
dataset = create_empty_dataset(
repo_id,
robot_type="mobile_aloha" if is_mobile else "aloha",
mode=mode,
has_effort=has_effort(hdf5_files),
has_velocity=has_velocity(hdf5_files),
dataset_config=dataset_config,
)
dataset = populate_dataset(
dataset,
hdf5_files,
task=task,
episodes=episodes,
)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
tyro.cli(port_aloha)


@@ -0,0 +1,57 @@
from typing import List, Optional # noqa: UP035
import einops
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override
from examples.aloha_real import real_env as _real_env
class AlohaRealEnvironment(_environment.Environment):
"""An environment for an Aloha robot on real hardware."""
def __init__(
self,
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
render_height: int = 224,
render_width: int = 224,
) -> None:
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
self._render_height = render_height
self._render_width = render_width
self._ts = None
@override
def reset(self) -> None:
self._ts = self._env.reset()
@override
def is_episode_complete(self) -> bool:
return False
@override
def get_observation(self) -> dict:
if self._ts is None:
raise RuntimeError("Timestep is not set. Call reset() first.")
obs = self._ts.observation
for k in list(obs["images"].keys()):
if "_depth" in k:
del obs["images"][k]
for cam_name in obs["images"]:
img = image_tools.convert_to_uint8(
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
)
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
return {
"state": obs["qpos"],
"images": obs["images"],
}
@override
def apply_action(self, action: dict) -> None:
self._ts = self._env.step(action["actions"])


@@ -0,0 +1,51 @@
import dataclasses
import logging
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro
from examples.aloha_real import env as _env
@dataclasses.dataclass
class Args:
host: str = "0.0.0.0"
port: int = 8000
action_horizon: int = 25
num_episodes: int = 1
max_episode_steps: int = 1000
def main(args: Args) -> None:
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
)
    metadata = ws_client_policy.get_server_metadata()
    logging.info(f"Server metadata: {metadata}")
runtime = _runtime.Runtime(
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=ws_client_policy,
action_horizon=args.action_horizon,
)
),
subscribers=[],
max_hz=50,
num_episodes=args.num_episodes,
max_episode_steps=args.max_episode_steps,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)


@@ -0,0 +1,176 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time
from typing import Optional, List
import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
from examples.aloha_real import constants
from examples.aloha_real import robot_utils
# This is the reset position that is used by the standard Aloha runtime.
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
class RealEnv:
"""
Environment for real robot bi-manual manipulation
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
"""
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
# reset_position = START_ARM_POSE[:6]
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
self.puppet_bot_left = InterbotixManipulatorXS(
robot_model="vx300s",
group_name="arm",
gripper_name="gripper",
robot_name="puppet_left",
init_node=init_node,
)
self.puppet_bot_right = InterbotixManipulatorXS(
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
)
if setup_robots:
self.setup_robots()
self.recorder_left = robot_utils.Recorder("left", init_node=False)
self.recorder_right = robot_utils.Recorder("right", init_node=False)
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
self.gripper_command = JointSingleCommand(name="gripper")
def setup_robots(self):
robot_utils.setup_puppet_bot(self.puppet_bot_left)
robot_utils.setup_puppet_bot(self.puppet_bot_right)
def get_qpos(self):
left_qpos_raw = self.recorder_left.qpos
right_qpos_raw = self.recorder_right.qpos
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
] # this is position not joint
right_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
] # this is position not joint
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
def get_qvel(self):
left_qvel_raw = self.recorder_left.qvel
right_qvel_raw = self.recorder_right.qvel
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
def get_effort(self):
left_effort_raw = self.recorder_left.effort
right_effort_raw = self.recorder_right.effort
left_robot_effort = left_effort_raw[:7]
right_robot_effort = right_effort_raw[:7]
return np.concatenate([left_robot_effort, right_robot_effort])
def get_images(self):
return self.image_recorder.get_images()
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
self.gripper_command.cmd = left_gripper_desired_joint
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
right_gripper_desired_pos_normalized
)
self.gripper_command.cmd = right_gripper_desired_joint
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
def _reset_joints(self):
robot_utils.move_arms(
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
)
def _reset_gripper(self):
"""Set to position mode and do position resets: first close then open. Then change back to PWM mode
NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
increase the frequency of motor faults.
"""
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
)
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
)
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
return obs
def get_reward(self):
return 0
def reset(self, *, fake=False):
if not fake:
# Reboot puppet robot gripper motors
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
self._reset_joints()
self._reset_gripper()
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def step(self, action):
state_len = int(len(action) / 2)
left_action = action[:state_len]
right_action = action[state_len:]
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
self.set_gripper_pose(left_action[-1], right_action[-1])
time.sleep(constants.DT)
return dm_env.TimeStep(
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def get_action(master_bot_left, master_bot_right):
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
# Arm actions
action[:6] = master_bot_left.dxl.joint_states.position[:6]
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
# Gripper actions
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
return action
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)


@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy>=1.22.4,<2.0.0
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets


@@ -0,0 +1,156 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
catkin-pkg==1.0.0
# via rospkg
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
contourpy==1.1.1
# via matplotlib
cycler==0.12.1
# via matplotlib
distro==1.9.0
# via rospkg
dm-control==1.0.23
# via -r examples/aloha_real/requirements.in
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
docutils==0.20.1
# via catkin-pkg
einops==0.8.0
# via -r examples/aloha_real/requirements.in
etils==1.3.0
# via mujoco
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
h5py==3.11.0
# via -r examples/aloha_real/requirements.in
idna==3.10
# via requests
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.7.5
# via -r examples/aloha_real/requirements.in
mdurl==0.1.2
# via markdown-it-py
modern-robotics==1.1.1
# via -r examples/aloha_real/requirements.in
msgpack==1.1.0
# via -r examples/aloha_real/requirements.in
mujoco==3.2.3
# via dm-control
numpy==1.24.4
# via
# -r examples/aloha_real/requirements.in
# contourpy
# dm-control
# dm-env
# h5py
# labmaze
# matplotlib
# modern-robotics
# mujoco
# opencv-python
# pyquaternion
# scipy
opencv-python==4.10.0.84
# via -r examples/aloha_real/requirements.in
packaging==24.2
# via
# -r examples/aloha_real/requirements.in
# matplotlib
pexpect==4.9.0
# via -r examples/aloha_real/requirements.in
pillow==10.4.0
# via
# -r examples/aloha_real/requirements.in
# matplotlib
protobuf==5.29.1
# via dm-control
ptyprocess==0.7.0
# via pexpect
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.1.4
# via
# catkin-pkg
# dm-control
# matplotlib
pyquaternion==0.9.9
# via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
# via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
# via
# catkin-pkg
# matplotlib
pyyaml==6.0.2
# via
# -r examples/aloha_real/requirements.in
# rospkg
requests==2.32.3
# via
# -r examples/aloha_real/requirements.in
# dm-control
rich==13.9.4
# via tyro
rospkg==1.5.1
# via -r examples/aloha_real/requirements.in
scipy==1.10.1
# via dm-control
setuptools==75.3.0
# via
# catkin-pkg
# dm-control
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_real/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_real/requirements.in
zipp==3.20.2
# via etils


@@ -0,0 +1,275 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time
from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState
from examples.aloha_real import constants
class ImageRecorder:
def __init__(self, init_node=True, is_debug=False):
self.is_debug = is_debug
self.bridge = CvBridge()
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
if init_node:
rospy.init_node("image_recorder", anonymous=True)
for cam_name in self.camera_names:
setattr(self, f"{cam_name}_rgb_image", None)
setattr(self, f"{cam_name}_depth_image", None)
setattr(self, f"{cam_name}_timestamp", 0.0)
if cam_name == "cam_high":
callback_func = self.image_cb_cam_high
elif cam_name == "cam_low":
callback_func = self.image_cb_cam_low
elif cam_name == "cam_left_wrist":
callback_func = self.image_cb_cam_left_wrist
elif cam_name == "cam_right_wrist":
callback_func = self.image_cb_cam_right_wrist
else:
raise NotImplementedError
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
if self.is_debug:
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
time.sleep(0.5)
def image_cb(self, cam_name, data):
setattr(
self,
f"{cam_name}_rgb_image",
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
)
# setattr(
# self,
# f"{cam_name}_depth_image",
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
# )
setattr(
self,
f"{cam_name}_timestamp",
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
)
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
if self.is_debug:
getattr(self, f"{cam_name}_timestamps").append(
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
)
def image_cb_cam_high(self, data):
cam_name = "cam_high"
return self.image_cb(cam_name, data)
def image_cb_cam_low(self, data):
cam_name = "cam_low"
return self.image_cb(cam_name, data)
def image_cb_cam_left_wrist(self, data):
cam_name = "cam_left_wrist"
return self.image_cb(cam_name, data)
def image_cb_cam_right_wrist(self, data):
cam_name = "cam_right_wrist"
return self.image_cb(cam_name, data)
def get_images(self):
image_dict = {}
for cam_name in self.camera_names:
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
time.sleep(0.00001)
rgb_image = getattr(self, f"{cam_name}_rgb_image")
depth_image = getattr(self, f"{cam_name}_depth_image")
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
image_dict[cam_name] = rgb_image
image_dict[f"{cam_name}_depth"] = depth_image
return image_dict
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
for cam_name in self.camera_names:
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
print(f"{cam_name} {image_freq=:.2f}")
print()
class Recorder:
def __init__(self, side, init_node=True, is_debug=False):
self.secs = None
self.nsecs = None
self.qpos = None
self.effort = None
self.arm_command = None
self.gripper_command = None
self.is_debug = is_debug
if init_node:
rospy.init_node("recorder", anonymous=True)
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_group",
JointGroupCommand,
self.puppet_arm_commands_cb,
)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_single",
JointSingleCommand,
self.puppet_gripper_commands_cb,
)
if self.is_debug:
self.joint_timestamps = deque(maxlen=50)
self.arm_command_timestamps = deque(maxlen=50)
self.gripper_command_timestamps = deque(maxlen=50)
time.sleep(0.1)
def puppet_state_cb(self, data):
self.qpos = data.position
self.qvel = data.velocity
self.effort = data.effort
self.data = data
if self.is_debug:
self.joint_timestamps.append(time.time())
def puppet_arm_commands_cb(self, data):
self.arm_command = data.cmd
if self.is_debug:
self.arm_command_timestamps.append(time.time())
def puppet_gripper_commands_cb(self, data):
self.gripper_command = data.cmd
if self.is_debug:
self.gripper_command_timestamps.append(time.time())
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
joint_freq = 1 / dt_helper(self.joint_timestamps)
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
def get_arm_joint_positions(bot):
return bot.arm.core.joint_states.position[:6]
def get_arm_gripper_positions(bot):
return bot.gripper.core.joint_states.position[6]
def move_arms(bot_list, target_pose_list, move_time=1):
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
for t in range(num_steps):
for bot_id, bot in enumerate(bot_list):
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
time.sleep(constants.DT)
def move_grippers(bot_list, target_pose_list, move_time):
print(f"Moving grippers to {target_pose_list=}")
gripper_command = JointSingleCommand(name="gripper")
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
for t in range(num_steps):
d = {}
for bot_id, bot in enumerate(bot_list):
gripper_command.cmd = traj_list[bot_id][t]
bot.gripper.core.pub_single.publish(gripper_command)
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
f.write(json.dumps(d) + "\n")
time.sleep(constants.DT)
def setup_puppet_bot(bot):
bot.dxl.robot_reboot_motors("single", "gripper", True)
bot.dxl.robot_set_operating_modes("group", "arm", "position")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_on(bot)
def setup_master_bot(bot):
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_off(bot)
def set_standard_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def set_low_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def torque_off(bot):
bot.dxl.robot_torque_enable("group", "arm", False)
bot.dxl.robot_torque_enable("single", "gripper", False)
def torque_on(bot):
bot.dxl.robot_torque_enable("group", "arm", True)
bot.dxl.robot_torque_enable("single", "gripper", True)
# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
print("\nSyncing!")
# activate master arms
torque_on(master_bot_left)
torque_on(master_bot_right)
# get puppet arm positions
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
# get puppet gripper positions
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
# move master arms to puppet positions
move_arms(
[master_bot_left, master_bot_right],
[puppet_left_qpos, puppet_right_qpos],
move_time=1,
)
# move master grippers to puppet positions
move_grippers(
[master_bot_left, master_bot_right],
[puppet_left_gripper, puppet_right_gripper],
move_time=1,
)


@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoDisplay(_subscriber.Subscriber):
"""Displays video frames."""
def __init__(self) -> None:
self._ax: plt.Axes | None = None
self._plt_img: plt.Image | None = None
@override
def on_episode_start(self) -> None:
plt.ion()
self._ax = plt.subplot()
self._plt_img = None
@override
def on_step(self, observation: dict, action: dict) -> None:
assert self._ax is not None
im = observation["image"][0] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
if self._plt_img is None:
self._plt_img = self._ax.imshow(im)
else:
self._plt_img.set_data(im)
plt.pause(0.001)
@override
def on_episode_end(self) -> None:
plt.ioff()
plt.close()


@@ -0,0 +1,41 @@
# Dockerfile for the Aloha simulation environment.
# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev
ENV MUJOCO_GL=egl
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]


@@ -0,0 +1,36 @@
# Run Aloha Sim
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client
# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```
Note: If you are seeing EGL errors, you may need to install the following dependencies:
```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```


@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
runtime:
image: aloha_sim
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_sim/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]


@@ -0,0 +1,56 @@
import gym_aloha # noqa: F401
import gymnasium
import numpy as np
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override
class AlohaSimEnvironment(_environment.Environment):
"""An environment for an Aloha robot in simulation."""
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
np.random.seed(seed)
self._rng = np.random.default_rng(seed)
self._gym = gymnasium.make(task, obs_type=obs_type)
self._last_obs = None
self._done = True
self._episode_reward = 0.0
@override
def reset(self) -> None:
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = False
self._episode_reward = 0.0
@override
def is_episode_complete(self) -> bool:
return self._done
@override
def get_observation(self) -> dict:
if self._last_obs is None:
raise RuntimeError("Observation is not set. Call reset() first.")
return self._last_obs # type: ignore
@override
def apply_action(self, action: dict) -> None:
gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = terminated or truncated
self._episode_reward = max(self._episode_reward, reward)
def _convert_observation(self, gym_obs: dict) -> dict:
img = gym_obs["pixels"]["top"]
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
# Convert axis order from [H, W, C] --> [C, H, W]
img = np.transpose(img, (2, 0, 1))
return {
"state": gym_obs["agent_pos"],
"images": {"cam_high": img},
}


@@ -0,0 +1,55 @@
import dataclasses
import logging
import pathlib
import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro
@dataclasses.dataclass
class Args:
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
task: str = "gym_aloha/AlohaTransferCube-v0"
seed: int = 0
action_horizon: int = 10
host: str = "0.0.0.0"
port: int = 8000
display: bool = False
def main(args: Args) -> None:
runtime = _runtime.Runtime(
environment=_env.AlohaSimEnvironment(
task=args.task,
seed=args.seed,
),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=_websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
),
action_horizon=args.action_horizon,
)
),
subscribers=[
_saver.VideoSaver(args.out_dir),
],
max_hz=50,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)


@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy>=1.22.4,<2.0.0
typing-extensions
tyro
websockets


@@ -0,0 +1,132 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
cloudpickle==3.1.0
# via gymnasium
contourpy==1.3.1
# via matplotlib
cycler==0.12.1
# via matplotlib
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
farama-notifications==0.0.4
# via gymnasium
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
gym-aloha==0.1.1
# via -r examples/aloha_sim/requirements.in
gymnasium==1.0.0
# via gym-aloha
idna==3.10
# via requests
imageio==2.36.1
# via
# -r examples/aloha_sim/requirements.in
# gym-aloha
imageio-ffmpeg==0.5.1
# via imageio
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.9.3
# via -r examples/aloha_sim/requirements.in
mdurl==0.1.2
# via markdown-it-py
msgpack==1.1.0
# via -r examples/aloha_sim/requirements.in
mujoco==2.3.7
# via
# dm-control
# gym-aloha
numpy==1.26.4
# via
# -r examples/aloha_sim/requirements.in
# contourpy
# dm-control
# dm-env
# gymnasium
# imageio
# labmaze
# matplotlib
# mujoco
# scipy
packaging==24.2
# via matplotlib
pillow==11.0.0
# via
# imageio
# matplotlib
protobuf==5.29.1
# via dm-control
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.2.0
# via
# dm-control
# matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
requests==2.32.3
# via dm-control
rich==13.9.4
# via tyro
scipy==1.14.1
# via dm-control
setuptools==75.6.0
# via
# dm-control
# imageio-ffmpeg
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.1
# via tyro
typing-extensions==4.12.2
# via
# -r examples/aloha_sim/requirements.in
# gymnasium
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_sim/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_sim/requirements.in


@@ -0,0 +1,40 @@
import logging
import pathlib
import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoSaver(_subscriber.Subscriber):
"""Saves episode data."""
def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
out_dir.mkdir(parents=True, exist_ok=True)
self._out_dir = out_dir
self._images: list[np.ndarray] = []
self._subsample = subsample
@override
def on_episode_start(self) -> None:
self._images = []
@override
def on_step(self, observation: dict, action: dict) -> None:
im = observation["images"]["cam_high"] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
self._images.append(im)
@override
def on_episode_end(self) -> None:
existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
out_path = self._out_dir / f"out_{next_idx}.mp4"
logging.info(f"Saving video to {out_path}")
imageio.mimwrite(
out_path,
[np.asarray(x) for x in self._images[:: self._subsample]],
fps=50 // max(1, self._subsample),
)


@@ -0,0 +1,212 @@
from collections import deque
from typing import List, Dict, Optional, Any, Sequence, Deque, Union
import datasets
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def check_final(
last_states: Union[Deque[Sequence[float]], Sequence[Sequence[float]], torch.Tensor],
*,
# Indices and initial state
arm_dofs: int = 6,  # number of left-arm joints (6 here)
gripper_index: int = -1,  # index of the gripper in the state vector (last dim by default)
mean_initial_arm_state: Optional[Sequence[float]] = (0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108),
mean_initial_gripper_state: float = 4.8438,  # not used in the decision yet; kept for future extension
# Decision thresholds (angles are given in degrees for easy tuning; converted to radians internally)
stability_window: int = 5,  # number of most recent frames used to judge "not much change"
per_joint_range_deg: float = 2.0,  # per-joint max-min range threshold within the window
mean_speed_deg: float = 0.5,  # threshold on the mean per-step L2 of frame-to-frame joint diffs (deg/step)
min_change_from_initial_deg: float = 15.0,  # minimum L2 change of the last frame relative to the initial pose
gripper_closed_thresh: float = 0.8,  # gripper-closed threshold (smaller means more closed)
) -> bool:
"""
Returns True when "in place": (1) the pose changes little over the recent window,
(2) the gripper is closed, and (3) the last frame differs enough from the initial pose.
All angle thresholds are given in degrees and converted to radians before comparison.
"""
# --- Normalize input to an (N, D) tensor ---
if isinstance(last_states, torch.Tensor):
states = last_states
else:
states = torch.as_tensor(list(last_states), dtype=torch.float32)
if states.ndim != 2:
raise ValueError(f"last_states should be 2D, got shape {tuple(states.shape)}")
N, D = states.shape
if D < arm_dofs:
raise ValueError(f"Expected at least {arm_dofs} dims for arm + gripper, got {D}")
if N < 2:
return False  # too few samples to judge stability
# Take the most recent window
w = min(N, stability_window)
window = states[-w:] # (w, D)
arm = window[:, :arm_dofs] # (w, 6)
last_arm = arm[-1] # (6,)
last_gripper = float(window[-1, gripper_index])
# --- 1) Little change over the last w frames ---
# Two criteria: each joint's range (max-min) must be small, and the mean frame-to-frame "speed" must be small.
deg2rad = torch.pi / 180.0
range_tol = per_joint_range_deg * deg2rad
speed_tol = mean_speed_deg * deg2rad
ranges = arm.max(dim=0).values - arm.min(dim=0).values  # (6,)
max_range = float(ranges.abs().max())  # scalar
diffs = arm[1:] - arm[:-1]  # (w-1, 6)
mean_speed = float(torch.linalg.norm(diffs, dim=1).mean())  # mean L2 per step
stable = (max_range <= range_tol) and (mean_speed <= speed_tol)
# --- 2) Gripper closed ---
gripper_closed = (last_gripper < gripper_closed_thresh)
# --- 3) Last frame must differ enough from the initial pose ---
init = torch.as_tensor(mean_initial_arm_state, dtype=last_arm.dtype, device=last_arm.device)
if init.numel() != arm_dofs:
raise ValueError(f"mean_initial_arm_state length {init.numel()} != arm_dofs {arm_dofs}")
dist_from_init = float(torch.linalg.norm(last_arm - init))
far_from_init = (dist_from_init >= (min_change_from_initial_deg * deg2rad))
# Combined decision
return bool(stable and gripper_closed and far_from_init)
# return bool(gripper_closed and far_from_init)
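As a quick sanity check of the stopping rule above, the three conditions can be evaluated by hand on a synthetic window (the numbers below are made up for illustration; the thresholds are the function defaults):

```python
import torch

torch.manual_seed(0)

# Synthetic window: 5 frames of 6 arm joints + 1 gripper value.
# The arm sits ~0.5 rad away from the initial pose with tiny jitter;
# the gripper reads 0.3 (< 0.8, i.e. closed).
init = torch.tensor([0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108])
arm = (init + 0.5).repeat(5, 1) + 1e-4 * torch.randn(5, 6)
gripper = torch.full((5, 1), 0.3)

deg2rad = torch.pi / 180.0
# 1) Stability: per-joint range and mean per-step speed both under threshold.
ranges = arm.max(dim=0).values - arm.min(dim=0).values
stable = float(ranges.max()) <= 2.0 * deg2rad and \
    float(torch.linalg.norm(arm[1:] - arm[:-1], dim=1).mean()) <= 0.5 * deg2rad
# 2) Gripper closed.
gripper_closed = float(gripper[-1, 0]) < 0.8
# 3) Last frame far enough from the initial pose.
far_from_init = float(torch.linalg.norm(arm[-1] - init)) >= 15.0 * deg2rad

print(stable and gripper_closed and far_from_init)  # -> True
```

For this window all three conditions hold, so `check_final` would report "in place".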
def get_last_frames(ds: LeRobotDataset, include_images: bool = False, keys=None):
"""
Quickly fetch the last frame of each episode in a LeRobotDataset.
- include_images=False: Return only scalar/vector fields from parquet (faster, no video decoding).
- include_images=True : Additionally decode the corresponding image/video frame for the last frame.
- keys: Limit the set of columns to retrieve (default: all non-image/video fields + timestamp, etc.).
Returns: list[dict], where each element contains the last frame info of one episode.
"""
# 1) Compute the global index of the last row for each episode.
# ds.episode_data_index['to'] is the exclusive end index, so last frame = to - 1.
end_idxs = (ds.episode_data_index["to"] - 1).tolist()
# 2) Determine which columns to load.
# By default, exclude video/image columns to avoid triggering slow video decoding.
if keys is None:
non_media_keys = [k for k, ft in ds.features.items() if ft["dtype"] not in ("image", "video")]
keys = list(set(non_media_keys + ["timestamp", "episode_index", "task_index"]))
# 3) Select all last-frame rows at once (does not call __getitem__, so no video decoding is triggered).
last_rows = ds.hf_dataset.select(end_idxs)
# 4) Build a dictionary of tensors for each requested key.
out = []
col = {k: last_rows[k] for k in keys}
# Convert lists of tensors into stacked tensors for easier indexing.
for k, v in col.items():
# datasets.arrow_dataset.Column is the HuggingFace internal type for columns.
if isinstance(v, datasets.arrow_dataset.Column) and len(v) > 0 and hasattr(v[0], "shape"):
col[k] = torch.stack(v[:])
# Iterate through each episode's last frame and build a dict with its values.
for i, ep_end in enumerate(end_idxs):
item = {}
for k in keys:
val = col[k][i]
# Unpack 0-dimensional tensors into Python scalars.
if torch.is_tensor(val) and val.ndim == 0:
val = val.item()
item[k] = val
# Map task_index back to the human-readable task string.
if "task_index" in item:
item["task"] = ds.meta.tasks[int(item["task_index"])]
out.append(item)
# 5) Optionally decode the actual image/video frame for each last timestamp.
if include_images and len(ds.meta.video_keys) > 0:
for i, ep_end in enumerate(end_idxs):
ep_idx = int(out[i]["episode_index"])
ts = float(out[i]["timestamp"])
# Prepare a query dictionary: one timestamp per camera key.
query_ts = {k: [ts] for k in ds.meta.video_keys}
# Decode video frames at the specified timestamps for this episode.
frames = ds._query_videos(query_ts, ep_idx)
# Attach the decoded frame tensors to the output dictionary.
for k, v in frames.items():
out[i][k] = v
return out
if __name__ == "__main__":
# Initialize your dataset (replace with your repo ID or local path).
ds = LeRobotDataset(repo_id="arx_lift2/pick_parcel_20250915")
# Retrieve metadata only (timestamps, states, actions, tasks) without decoding video.
last_infos = get_last_frames(ds, include_images=False)
# Stack all 'observation.state' vectors into a single tensor for further processing.
states = torch.stack([info['observation.state'] for info in last_infos])
# Extract the left-arm joint states (first 7 values of each state vector).
left_arm_states = states[:, 0:7]
mean_state = torch.mean(left_arm_states, dim=0)
std_state = torch.std(left_arm_states, dim=0)
# Print the collected metadata for verification.
print(last_infos)
# --- Run check_final per episode using the last <= STABILITY_WINDOW states ---
EP_ARM_DOFS = 6  # number of left-arm joints we use in check_final
GRIPPER_COL_FULL = -1  # gripper is the last element in the full state vector
STABILITY_WINDOW = 120  # trailing frames passed to check_final (overrides its default of 5)
# Determine which episodes to iterate
episode_indices = ds.episodes if ds.episodes is not None else sorted(ds.meta.episodes.keys())
episode_flags = {}
num_true, num_false = 0, 0
for ep_idx in episode_indices:
# Global index range [from_idx, to_idx) for this episode
from_idx = int(ds.episode_data_index["from"][ep_idx])
to_idx = int(ds.episode_data_index["to"][ep_idx])
if to_idx - from_idx <= 0:
episode_flags[ep_idx] = False
num_false += 1
continue
# Take the last <= STABILITY_WINDOW frames from this episode
idxs = list(range(max(from_idx, to_idx - STABILITY_WINDOW), to_idx))
rows = ds.hf_dataset.select(idxs)
# Collect full "observation.state" (shape ~ [W, S])
s_col = rows["observation.state"]
if isinstance(s_col, datasets.arrow_dataset.Column):
S = torch.stack(s_col[:]) # Column -> list[tensor] -> stack
else:
S = torch.stack(s_col) # already a list[tensor]
# Build the 7D small state per frame: first 6 joints + gripper
# (Assumes the gripper signal is at the last position of the full state vector)
small_states = torch.cat([S[:, :EP_ARM_DOFS], S[:, EP_ARM_DOFS:EP_ARM_DOFS+1]], dim=1)
# Run your stopping logic
ok = check_final(
small_states,
arm_dofs=EP_ARM_DOFS,
gripper_index=-1,
stability_window=STABILITY_WINDOW,
)
episode_flags[ep_idx] = bool(ok)
num_true += int(ok)
num_false += int(not ok)
# Summary
total_eps = len(episode_indices)
print(f"[check_final] passed: {num_true} / {total_eps} ({(num_true/max(total_eps,1)):.1%})")
# List some failed episodes for quick inspection
failed_eps = [e for e, passed in episode_flags.items() if not passed]
print("Failed episode indices (first 20):", failed_eps[:20])


@@ -0,0 +1,88 @@
import os
import cv2
from pathlib import Path
from tqdm import tqdm
def extract_last_frame_from_videos(root_dir, output_dir, xx_last_frame=1):
"""
遍历目录找到所有images.rgb.hand_right视频文件提取最后一帧并保存
"""
# Find all mp4 video files
video_files = []
for root, dirs, files in os.walk(root_dir):
for file in files:
if file.endswith('.mp4') and 'observation/head' in root:
video_files.append(os.path.join(root, file))
print(f"找到 {len(video_files)} 个视频文件")
# 处理每个视频文件
for video_path in tqdm(video_files):
try:
# Extract the set name and episode name
path_parts = Path(video_path).parts
set_name = None
episode_name = None
for part in path_parts:
if part.startswith('set'):
set_name = part
if part.startswith('000'):
episode_name = part.replace('.mp4', '')
if not set_name or not episode_name:
print(f"无法从路径中提取set和episode信息: {video_path}")
continue
# 生成输出文件名
output_filename = f"{set_name}_{episode_name}.jpg"
output_path = os.path.join(output_dir, output_filename)
# Open the video file
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Could not open video: {video_path}")
continue
# Get the total frame count
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames == 0:
print(f"Video has no frames: {video_path}")
cap.release()
continue
# Seek to the requested frame counting from the end
cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - xx_last_frame)
ret, frame = cap.read()
if ret:
# Save the extracted frame
cv2.imwrite(output_path, frame)
print(f"Saved:\n  {output_path}")
else:
print(f"Could not read the last frame: {video_path}")
# Release the capture
cap.release()
except Exception as e:
print(f"Error while processing video {video_path}: {str(e)}")
if __name__ == "__main__":
# Root directory to walk
root_directory = "/home/caijunhao/h-ceph/InternData-A1-raw/arx_lift2/Pick_the_industrial_components_from_the_conveyor"  # change this to your own directory path
output_path = 'data/Pick_the_industrial_components_from_the_conveyor/'
os.makedirs(output_path, exist_ok=True)
sub_list = os.listdir(root_directory)
exclude_list = []
# exclude_list = [f"{i}" for i in range(16)] + [f"{i}" for i in range(26, 29)]
xx_last_frame = 1
# import pdb
# pdb.set_trace()
for sub in tqdm(sub_list):
if sub.split('-')[1].split('_')[0] in exclude_list:
continue
# print("os.path.join([root_directory, sub])\n", os.path.join(root_directory, sub))
extract_last_frame_from_videos(os.path.join(root_directory, sub), output_path, xx_last_frame=xx_last_frame)
print("处理完成!")


@@ -0,0 +1,670 @@
# source /fs-computility/efm/liyang/miniconda3/etc/profile.d/conda.sh
# conda activate act
import argparse
import json
import logging
import os
import gc
import shutil
from concurrent.futures import ALL_COMPLETED, ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
import torchvision
import cv2
import h5py
import lmdb
import numpy as np
import pickle
import torch
from PIL import Image
from scipy.spatial.transform import Rotation
from tqdm import tqdm
import pdb
import imageio # imageio-ffmpeg
from lerobot.common.datasets.compute_stats import auto_downsample_height_width, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import check_timestamps_sync, get_episode_data_index, validate_episode_buffer
import time
# import ray
# from ray.runtime_env import RuntimeEnv
"""
Store both camera image and robot state as a combined observation.
Args:
observation: images(camera), states (robot state)
actions: joint, gripper, ee_pose
"""
FEATURES = {
"images.rgb.head": {
"dtype": "video",
"shape": (368, 640, 3),
"names": ["height", "width", "channel"],
},
"images.rgb.hand_left": {
"dtype": "video",
"shape": (480, 640, 3),
"names": ["height", "width", "channel"],
},
"images.rgb.hand_right": {
"dtype": "video",
"shape": (480, 640, 3),
"names": ["height", "width", "channel"],
},
# "states.left_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
# },
# "states.left_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["left_gripper_0",],
# },
# "states.right_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
# },
# "states.right_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["right_gripper_0",],
# },
"observation.state": {
"dtype": "float32",
"shape": (14,),
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
},
"action": {
"dtype": "float32",
"shape": (14,),
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
},
# "actions.left_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
# },
# "actions.left_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["left_gripper_0",],
# },
# "actions.right_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
# },
# "actions.right_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["right_gripper_0", ],
# },
}
import numpy as np
def filter_forbidden_frames(state_dict, position_threshold=0.001):
"""
Filter out forbidden (static) frames based on a position-change threshold.
Args:
- state_dict: state array of shape (n, 14)
- position_threshold: threshold on the summed absolute frame-to-frame position change
Returns:
- valid_mask: boolean array; True marks a valid frame
"""
# Use all 14 columns, including the gripper columns (indices 6 and 13)
qpos_data = state_dict
n_frames = len(state_dict)
valid_mask = np.ones(n_frames, dtype=bool)
# Frame-to-frame differences serve as a velocity proxy
if n_frames > 1:
diff_sum = np.sum(np.abs(np.diff(qpos_data, axis=0)), axis=1)
# A frame is valid if it shows motion above the threshold;
# the last frame has no successor and is always kept.
valid_mask[:-1] = diff_sum > position_threshold
return valid_mask
def statistical_filter(state_dict, std_multiplier=2.0):
"""
Detect anomalous (forbidden) frames with a simple statistical test.
"""
# Exclude the gripper columns (indices 6 and 13)
qpos_columns = [i for i in range(14) if i not in [6, 13]]
qpos_data = state_dict[:, qpos_columns]
# Per-column mean and standard deviation
means = np.mean(qpos_data, axis=0)
stds = np.std(qpos_data, axis=0)
# Build the validity mask
valid_mask = np.ones(len(state_dict), dtype=bool)
for i in range(len(state_dict)):
# Check whether each joint position lies within a reasonable range
deviations = np.abs(qpos_data[i] - means)
if np.any(deviations > std_multiplier * stds):
valid_mask[i] = False  # anomalous frame
return valid_mask
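A minimal, self-contained sanity check for the motion-based filtering idea used above (the `motion_mask` helper below re-implements the same rule standalone; the thresholds and trajectory values are illustrative):

```python
import numpy as np

def motion_mask(qpos, position_threshold=0.001):
    # True where the summed absolute frame-to-frame change exceeds the
    # threshold; the final frame has no successor and is always kept.
    mask = np.ones(len(qpos), dtype=bool)
    if len(qpos) > 1:
        diff_sum = np.sum(np.abs(np.diff(qpos, axis=0)), axis=1)
        mask[:-1] = diff_sum > position_threshold
    return mask

# Five static frames followed by five moving frames, 14 joint dims.
static = np.zeros((5, 14), dtype=np.float32)
moving = np.cumsum(np.full((5, 14), 0.01, dtype=np.float32), axis=0)
mask = motion_mask(np.concatenate([static, moving]))
print(mask.tolist())  # [False, False, False, False, True, True, True, True, True, True]
```

Only the first four fully static frames are dropped; the frame just before motion begins already shows a nonzero difference to its successor.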
class ARXDataset(LeRobotDataset):
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__(
repo_id=repo_id,
root=root,
episodes=episodes,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=tolerance_s,
download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend,
)
def save_episode(self, episode_data: dict | None = None, videos: dict | None = None) -> None:
if not episode_data:
episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
continue
episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
for key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
episode_buffer[key] = str(video_path) # PosixPath -> str
video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(videos[key], video_path)
ep_stats = compute_episode_stats(episode_buffer, self.features)
self._save_episode_table(episode_buffer, episode_index)
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
check_timestamps_sync(
episode_buffer["timestamp"],
episode_buffer["episode_index"],
ep_data_index_np,
self.fps,
self.tolerance_s,
)
if not episode_data:
self.episode_buffer = self.create_episode_buffer()
def add_frame(self, frame: dict) -> None:
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
features = {key: value for key, value in self.features.items() if key in self.hf_features}
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
for key in frame:
if key == "task":
self.episode_buffer["task"].append(frame["task"])
continue
if key not in self.features:
raise ValueError(f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'.")
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
# def crop_resize_no_padding(image, target_size=(480, 640)):
# """
# Crop and scale to target size (no padding)
# :param image: input image (NumPy array)
# :param target_size: target size (height, width)
# :return: processed image
# """
# h, w = image.shape[:2]
# target_h, target_w = target_size
# target_ratio = target_w / target_h # Target aspect ratio (e.g. 640/480=1.333)
# # the original image aspect ratio and cropping direction
# if w / h > target_ratio: # Original image is wider → crop width
# crop_w = int(h * target_ratio) # Calculate crop width based on target aspect ratio
# crop_h = h
# start_x = (w - crop_w) // 2 # Horizontal center starting point
# start_y = 0
# else: # Original image is higher → crop height
# crop_h = int(w / target_ratio) # Calculate clipping height according to target aspect ratio
# crop_w = w
# start_x = 0
# start_y = (h - crop_h) // 2 # Vertical center starting point
# # Perform centered cropping (to prevent out-of-bounds)
# start_x, start_y = max(0, start_x), max(0, start_y)
# end_x, end_y = min(w, start_x + crop_w), min(h, start_y + crop_h)
# cropped = image[start_y:end_y, start_x:end_x]
# # Resize to target size (bilinear interpolation)
# resized = cv2.resize(cropped, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
# return resized
def load_lmdb_data(episode_path: Path, sava_path: Path, fps_factor: int, target_fps: int) -> Optional[Dict]:
def load_image(txn, key):
raw = txn.get(key)
data = pickle.loads(raw)
image = cv2.imdecode(data, cv2.IMREAD_COLOR)
# Convert to RGB if necessary
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image = crop_resize_no_padding(image, target_size=(480, 640))
return image
try:
env = lmdb.open(
str(episode_path / "lmdb"),
readonly=True,
lock=False,
max_readers=128,
readahead=False
)
with env.begin(write=False) as txn:
keys = [k for k, _ in txn.cursor()]
image_keys = sorted([k for k in keys if b'head' in k])
if not image_keys:
return None
all_qpos = pickle.loads(txn.get(b'/observations/qpos'))
if np.isscalar(all_qpos):
total_steps = len(image_keys)
all_qpos = [all_qpos] * total_steps
else:
total_steps = len(all_qpos)
all_qpos = np.stack(all_qpos)
state_action_dict = {}
state_action_dict["states.left_joint.position"] = all_qpos[:, :6]
state_action_dict["states.left_gripper.position"] = all_qpos[:, 6][:, None] # np.expand_dims(all_qpos[:, 6], axis=1)
state_action_dict["states.right_joint.position"] = all_qpos[:, 7:13]
state_action_dict["states.right_gripper.position"] = all_qpos[:, 13][:, None] # np.expand_dims(all_qpos[:, 13], axis=1)
# state_keys = list(state_action_dict.keys())
# for k in state_keys:
# state_action_dict[k.replace("states", "actions")] = np.concatenate([state_action_dict[k][1:, :], state_action_dict[k][-1, :][None,:]], axis=0)
# action_dict = {}
# action_dict["actions.left_joint.position"] = np.concatenate([state_dict["states.left_joint.position"][1:, :], state_dict["states.left_joint.position"][-1, :][None,:]], axis=0)
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
# action_dict["actions.right_joint.position"] = state_dict["states.right_joint.position"][1:, :]
# action_dict["actions.right_gripper.position"] = state_dict["states.right_gripper.position"][1:, :]
action_dict = {}
# Action at step t is the state at step t + 1; the final step repeats the last state.
action_dict["action"] = np.concatenate([all_qpos[1:], all_qpos[-1:].reshape(-1, 14)], axis=0)
state_dict = {}
state_dict["observation.state"] = all_qpos
mask1 = filter_forbidden_frames(state_dict["observation.state"])
# state_dict["observation.state"] = state_dict["observation.state"][mask1]
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
# action_dict["actions.right_arm.position"] = np.concatenate([state_action_dict["states.right_joint.position"][1:, :], state_action_dict["states.right_joint.position"][-1, :][None,:]], axis=0)
# action_dict["actions.left_arm.position"] = state_dict["states.right_gripper.position"][1:, :]
assert total_steps == len(image_keys), "qpos length mismatch"
selected_steps = [step for step in range(total_steps) if step % fps_factor == 0 and mask1[step]]
frames = []
image_observations = {
"images.rgb.head": [],
"images.rgb.hand_left": [],
"images.rgb.hand_right": []
}
start_time = time.time()
for step_index, step in enumerate(selected_steps):
step_str = f"{step:04d}"
head_key = f"observation/head/color_image/{step_str}".encode()
left_key = f"observation/left_wrist/color_image/{step_str}".encode()
right_key = f"observation/right_wrist/color_image/{step_str}".encode()
if not (head_key in keys and left_key in keys and right_key in keys):
continue
# state = all_qpos[step]
# if step_index < len(selected_steps) - 1:
# action = all_qpos[selected_steps[step_index + 1]]
# else:
# action = state
data_dict = {}
# for key, value in state_action_dict.items():
# data_dict[key] = value[step]
data_dict['action'] = action_dict["action"][step]
data_dict["task"] = " ".join(episode_path.parent.parent.name.split("_"))
data_dict['observation.state'] = state_dict["observation.state"][step]
# frames.append({
# "observation.states.joint.position": state,
# "actions.joint.position": action,
# "task": task_name,
# })
frames.append(data_dict)
image_observations["images.rgb.head"].append(load_image(txn, head_key))
image_observations["images.rgb.hand_left"].append(load_image(txn, left_key))
image_observations["images.rgb.hand_right"].append(load_image(txn, right_key))
end_time = time.time()
elapsed_time = end_time - start_time
print(f"loaded image observations of {episode_path} in {elapsed_time:.2f}s")
env.close()
if not frames:
return None
os.makedirs(sava_path, exist_ok=True)
os.makedirs(sava_path/episode_path.name, exist_ok=True)
imageio.mimsave(sava_path/episode_path.name/'head.mp4', image_observations["images.rgb.head"], fps=target_fps)
imageio.mimsave(sava_path/episode_path.name/'hand_left.mp4', image_observations["images.rgb.hand_left"], fps=target_fps)
imageio.mimsave(sava_path/episode_path.name/'hand_right.mp4', image_observations["images.rgb.hand_right"], fps=target_fps)
print(f"imageio.mimsave time taken of {episode_path}")
return {
"frames": frames,
"videos": {
"images.rgb.head": sava_path/episode_path.name/"head.mp4",
"images.rgb.hand_left": sava_path/episode_path.name/"hand_left.mp4",
"images.rgb.hand_right": sava_path/episode_path.name/"hand_right.mp4",
},
}
except Exception as e:
logging.error(f"Failed to load LMDB data: {e}")
return None
def get_all_tasks(src_path: Path, output_path: Path) -> Tuple[Path, Path]:
src_dirs = sorted(list(src_path.glob("*"))) # "set*-*_collector*_datetime" as the conversion unit
save_dirs = [output_path/_dir.parent.name/_dir.name for _dir in src_dirs]
tasks_tuples = zip(src_dirs, save_dirs)
for task in tasks_tuples:
yield task
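For reference, the generator above simply mirrors each source task directory under the output root as `<parent name>/<task name>`; a quick standalone check with a temporary directory (the directory names are made up for illustration):

```python
import tempfile
from pathlib import Path

def mirror_tasks(src_path, output_path):
    # Same pairing rule as get_all_tasks: each task dir is mirrored under
    # the output root as <parent name>/<task name>.
    for d in sorted(src_path.glob("*")):
        yield d, output_path / d.parent.name / d.name

with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "Make_a_beef_sandwich"
    (src / "set1-0_collector1_20250101").mkdir(parents=True)
    out = Path(tmp) / "lerobot"
    src_dir, save_dir = next(mirror_tasks(src, out))
    print(save_dir.relative_to(out).as_posix())  # Make_a_beef_sandwich/set1-0_collector1_20250101
```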
def compute_episode_stats(episode_data: Dict[str, List[str] | np.ndarray], features: Dict) -> Dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else:
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def sample_images(input):
if isinstance(input, str):
reader = torchvision.io.VideoReader(input, stream="video")
frames = [frame["data"] for frame in reader]
frames_array = torch.stack(frames).numpy()  # Shape: [T, C, H, W]
elif isinstance(input, np.ndarray):
frames_array = input[:, None, :, :]  # Shape: [T, C, H, W]
sampled_indices = sample_indices(len(frames_array))
images = None
for i, idx in enumerate(sampled_indices):
img = auto_downsample_height_width(frames_array[idx])
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
images[i] = img
return images
def load_local_dataset(episode_path: str, save_path:str, origin_fps=30, target_fps=30):
fps_factor = origin_fps // target_fps
# print(f"fps downsample factor: {fps_factor}")
# logging.info(f"fps downsample factor: {fps_factor}")
# for format_str in [f"{episode_id:07d}", f"{episode_id:06d}", str(episode_id)]:
# episode_path = Path(src_path) / format_str
# save_path = Path(save_path) / format_str
# if episode_path.exists():
# break
# else:
# logging.warning(f"Episode directory not found for ID {episode_id}")
# return None, None
episode_path = Path(episode_path)
if not episode_path.exists():
logging.warning(f"{episode_path} does not exist")
return None, None
if not (episode_path / "lmdb/data.mdb").exists():
logging.warning(f"LMDB data not found for episode {episode_path}")
return None, None
raw_dataset = load_lmdb_data(episode_path, save_path, fps_factor, target_fps)
if raw_dataset is None:
return None, None
frames = raw_dataset["frames"] # states, actions, task
videos = raw_dataset["videos"] # image paths
## check the frames
for camera_name, video_path in videos.items():
if not os.path.exists(video_path):
logging.error(f"Video file {video_path} does not exist.")
print(f"Camera {camera_name} Video file {video_path} does not exist.")
return None, None
return frames, videos
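The temporal downsampling in `load_local_dataset` boils down to keeping every `fps_factor`-th step; a tiny sketch of the selection rule (step count is illustrative, and `origin_fps` must be an integer multiple of `target_fps`, as the script asserts before running):

```python
# Downsampling a 30 fps recording to 10 fps keeps every third step.
origin_fps, target_fps = 30, 10
fps_factor = origin_fps // target_fps
selected = [s for s in range(12) if s % fps_factor == 0]
print(selected)  # [0, 3, 6, 9]
```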
def save_as_lerobot_dataset(task: tuple[Path, Path], repo_id, num_threads, debug, origin_fps=30, target_fps=30, robot_type="piper", delete_downsampled_videos=True):
src_path, save_path = task
print(f"**Processing collected** {src_path}")
print(f"**saving to** {save_path}")
if save_path.exists():
# print(f"Output directory {save_path} already exists. Deleting it.")
# logging.warning(f"Output directory {save_path} already exists. Deleting it.")
# shutil.rmtree(save_path)
print(f"Output directory {save_path} already exists.")
return
dataset = ARXDataset.create(
repo_id=f"{repo_id}",
root=save_path,
fps=target_fps,
robot_type=robot_type,
features=FEATURES,
)
all_episode_paths = sorted([f.as_posix() for f in src_path.glob("*") if f.is_dir()])
# all_subdir_eids = [int(Path(path).name) for path in all_subdir]
if debug:
for i in range(1):
frames, videos = load_local_dataset(episode_path=all_episode_paths[i], save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
for frame_data in frames:
dataset.add_frame(frame_data)
dataset.save_episode(videos=videos)
if delete_downsampled_videos:
for _, video_path in videos.items():
parent_dir = os.path.dirname(video_path)
try:
shutil.rmtree(parent_dir)
# os.remove(video_path)
# print(f"Successfully deleted: {parent_dir}")
print(f"Successfully deleted: {video_path}")
except Exception as e:
pass # Handle the case where the directory might not exist or is already deleted
else:
for batch_index in range(len(all_episode_paths)//num_threads+1):
batch_episode_paths = all_episode_paths[batch_index*num_threads:(batch_index+1)*num_threads]
if len(batch_episode_paths) == 0:
continue
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
for episode_path in batch_episode_paths:
print("starting to process episode: ", episode_path)
futures.append(
executor.submit(load_local_dataset, episode_path=episode_path, save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
)
for raw_dataset in as_completed(futures):
frames, videos = raw_dataset.result()
if frames is None or videos is None:
print("Skipping an episode due to missing data.")
continue
for frame_data in frames:
dataset.add_frame(frame_data)
dataset.save_episode(videos=videos)
gc.collect()
print(f"finishing processed {videos}")
if delete_downsampled_videos:
for _, video_path in videos.items():
# Get the parent directory of the video
parent_dir = os.path.dirname(video_path)
try:
shutil.rmtree(parent_dir)
print(f"Successfully deleted: {parent_dir}")
except Exception as e:
pass
def main(src_path, save_path, repo_id, num_threads=60, debug=False, origin_fps=30, target_fps=30):
logging.info("Scanning for episodes...")
tasks = get_all_tasks(src_path, save_path)
if debug:
task = next(tasks)
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
else:
for task in tasks:
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert collected data from Piper to Lerobot format.")
parser.add_argument(
"--src_path",
type=str,
# required=False,
default="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/",
help="Path to the input file containing collected data in Piper format.",
#help="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Make_a_beef_sandwich",
)
parser.add_argument(
"--save_path",
type=str,
# required=False,
default="/fs-computility/efm/shared/datasets/myData-A1/real/lerobot_v2_1/agilex_split_aloha/",
help="Path to the output file where the converted Lerobot format will be saved.",
#help="Path to the output file where the converted Lerobot format will be saved.",
)
parser.add_argument(
"--debug",
action="store_true",
help="Run in debug mode with limited episodes",
)
parser.add_argument(
"--num-threads",
type=int,
default=50,
help="Number of threads per process",
)
# parser.add_argument(
# "--task_name",
# type=str,
# required=True,
# default="Pick_up_the_marker_and_put_it_into_the_pen_holder",
# help="Name of the task to be processed. Default is 'Pick_up_the_marker_and_put_it_into_the_pen_holder'.",
# )
parser.add_argument(
"--repo_id",
type=str,
required=True,
# default="SplitAloha_20250714",
help="identifier for the dataset repository.",
)
parser.add_argument(
"--origin_fps",
type=int,
default=30,
help="Frames per second for the obervation video. Default is 30.",
)
parser.add_argument(
"--target_fps",
type=int,
default=30,
help="Frames per second for the downsample video. Default is 30.",
)
args = parser.parse_args()
assert int(args.origin_fps) % int(args.target_fps) == 0, "origin_fps must be an integer multiple of target_fps"
start_time = time.time()
main(
src_path=Path(args.src_path),
save_path=Path(args.save_path),
repo_id=args.repo_id,
num_threads=args.num_threads,
debug=args.debug,
origin_fps=args.origin_fps,
target_fps=args.target_fps
)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Total time taken: {elapsed_time:.2f} seconds")
# --target_fps 10
# --src_path /fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Put_the_bananas_in_the_basket
# --save_path /mnt/shared-storage-user/internvla/Users/liyang/data/processed_data/arx_lift2

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,587 @@
#!/usr/bin/env python3
"""
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
This script loads a JAX model checkpoint using orbax and can either:
1. Print out all the parameter keys in a hierarchical structure for inspection
2. Convert the JAX model to PyTorch format using our PI0Pytorch model
Usage:
# Just inspect keys:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
# Convert to PyTorch:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
Example:
# pi0_droid
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
# pi0_aloha_sim
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
# pi05_droid
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
"""
import json
import os
import pathlib
import shutil
from typing import Literal
from flax.nnx import traversals
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import torch
import tyro
import openpi.models.gemma
import openpi.models.model
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config
def slice_paligemma_state_dict(state_dict, config):
"""Convert PaliGemma JAX parameters to PyTorch format."""
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
# patch embeddings
jax_key = f"img/embedding/kernel{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
jax_key = f"img/embedding/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# positional embeddings
jax_key = f"img/pos_embedding{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
encoderblock_attention_0_key_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
)
encoderblock_attention_0_key_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
)
encoderblock_attention_0_value_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
)
encoderblock_attention_0_value_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
)
encoderblock_attention_0_query_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
)
encoderblock_attention_0_query_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
)
encoderblock_attention_0_out_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
)
encoderblock_attention_0_out_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
)
for i in range(config.vision_config.num_hidden_layers):
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
] = encoderblock_layernorm0_scale[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
] = encoderblock_layernorm0_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
] = encoderblock_layernorm1_scale[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
] = encoderblock_layernorm1_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
] = encoderblock_mlp_dense0_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
] = encoderblock_mlp_dense1_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# multimodal projector
jax_key = f"img/head/kernel{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
jax_key = f"img/head/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# text decoder (gemma)
jax_key = f"llm/embedder/input_embedding{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# pop the einsum attention + mlp representations
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
for i in range(config.text_config.num_hidden_layers):
q_proj_weight_reshaped = (
llm_attention_q_einsum[i]
.transpose(0, 2, 1)
.reshape(
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
)
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
q_proj_weight_reshaped
)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
k_proj_weight_reshaped
)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
v_proj_weight_reshaped
)
o_proj_weight_reshaped = (
llm_attention_attn_vec_einsum[i]
.transpose(2, 0, 1)
.reshape(
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
)
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
o_proj_weight_reshaped
)
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
gate_proj_weight.transpose()
)
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
up_proj_weight.transpose()
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
llm_mlp_linear[i].transpose()
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
llm_input_layernorm[i]
)
state_dict[
f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
] = llm_post_attention_layernorm[i]
jax_key = f"llm/final_norm/scale{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key)
expert_dict = {}
final_state_dict = {}
# Expert-related keys to extract (including pi05 Dense layer parameters)
expert_keys = [
f"llm/final_norm_1/scale{suffix}",
f"llm/final_norm_1/Dense_0/bias{suffix}",
f"llm/final_norm_1/Dense_0/kernel{suffix}",
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
f"llm/layers/attn/kv_einsum_1/w{suffix}",
f"llm/layers/attn/q_einsum_1/w{suffix}",
f"llm/layers/mlp_1/gating_einsum{suffix}",
f"llm/layers/mlp_1/linear{suffix}",
f"llm/layers/pre_attention_norm_1/scale{suffix}",
f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
]
for key, value in state_dict.items():
if key not in expert_keys:
final_state_dict[key] = torch.from_numpy(value)
else:
expert_dict[key] = value
return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
"""Convert Gemma JAX parameters to PyTorch format."""
# Add missing attributes to config if they don't exist
if not hasattr(config, "vocab_size"):
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
if not hasattr(config, "hidden_size"):
config.hidden_size = config.width
if not hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = config.depth
if not hasattr(config, "num_attention_heads"):
config.num_attention_heads = config.num_heads
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
# Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
    if pi05:
# Pi05 with adaptive normalization
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
llm_input_layernorm_kernel = state_dict.pop(
f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
)
llm_post_attention_layernorm_kernel = state_dict.pop(
f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
)
else:
# Regular pi0 with standard RMSNorm
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
for i in range(config.num_hidden_layers):
q_proj_weight_reshaped = (
llm_attention_q_einsum[i]
.transpose(0, 2, 1)
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
q_proj_weight_reshaped
)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
k_proj_weight_reshaped
)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
v_proj_weight_reshaped
)
o_proj_weight_reshaped = (
llm_attention_attn_vec_einsum[i]
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
.transpose(1, 0)
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
o_proj_weight_reshaped
)
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
gate_proj_weight.transpose()
)
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
up_proj_weight.transpose()
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
i
].transpose()
        if pi05:
# Pi05 with adaptive normalization - use Dense layer parameters directly
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
llm_input_layernorm_bias[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
llm_post_attention_layernorm_bias[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
llm_input_layernorm_kernel[i].transpose()
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
llm_post_attention_layernorm_kernel[i].transpose()
)
else:
# Regular pi0 with standard RMSNorm
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
llm_input_layernorm[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
llm_post_attention_layernorm[i]
)
# Handle final norm layer
    if pi05:
# Pi05 with adaptive normalization - use Dense layer parameters directly
final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
else:
# Regular pi0 with standard RMSNorm
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
f"llm/final_norm_{num_expert}/scale{suffix}"
)
# state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
final_state_dict = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
final_state_dict[key] = torch.from_numpy(value)
else:
final_state_dict[key] = value
return final_state_dict
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
"""Load and process params by restoring via JAX model loader first.
This respects dtype conversions that occur during model restore.
"""
# Use repository restore utility to load a pure dict of params (value suffix removed)
params = openpi.models.model.restore_params(
f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
)
return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
def load_jax_model_and_print_keys(checkpoint_dir: str):
"""
Load JAX model from checkpoint and print all parameter keys.
Args:
checkpoint_dir: Path to the checkpoint directory
"""
checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
# Initialize checkpointer
checkpointer = ocp.PyTreeCheckpointer()
metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
print(utils.array_tree_to_info(metadata))
def convert_pi0_checkpoint(
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
"""
Convert PI0 JAX checkpoint to PyTorch format.
Args:
checkpoint_dir: Path to the JAX checkpoint
precision: Model precision (float32, bfloat16, float16)
output_path: Path to save the converted PyTorch model
model_config: Model config
"""
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
print(f"Model config: {model_config}")
# Break down orbax ckpts by restoring via JAX to respect dtype
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
# Process projection params
if model_config.pi05:
keys = [
"action_in_proj",
"action_out_proj",
"time_mlp_in",
"time_mlp_out",
]
else:
keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
"action_time_mlp_in",
"action_time_mlp_out",
]
projection_params = {}
for key in keys:
kernel_params = initial_params["projection_params"][key]["kernel"]
bias_params = initial_params["projection_params"][key]["bias"]
if isinstance(kernel_params, dict):
weight = kernel_params["value"]
bias = bias_params["value"]
else:
weight = kernel_params
bias = bias_params
pytorch_weight_key = f"{key}.weight"
pytorch_bias_key = f"{key}.bias"
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
# Create configs based on checkpoint path
# All models use the same PaliGemma config structure
class PaliGemmaConfig:
def __init__(self):
self.vision_config = type(
"obj",
(object,),
{
"hidden_size": 1152,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"intermediate_size": 4304,
"patch_size": 14,
"projection_dim": 2048,
},
)()
self.text_config = type(
"obj",
(object,),
{
"hidden_size": 2048,
"num_hidden_layers": 18,
"num_attention_heads": 8,
"head_dim": 256,
"intermediate_size": 16384,
},
)()
paligemma_config = PaliGemmaConfig()
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
# Process PaliGemma weights
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
# Process Gemma weights from expert_params
gemma_params = slice_gemma_state_dict(
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
)
# Instantiate model
pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
# Combine all parameters (no prefix needed for our model structure)
all_params = {**paligemma_params, **gemma_params, **projection_params}
# Load state dict
pi0_model.load_state_dict(all_params, strict=False)
    if precision == "float32":
        pi0_model = pi0_model.to(torch.float32)
    elif precision == "bfloat16":
        pi0_model = pi0_model.to(torch.bfloat16)
    elif precision == "float16":
        pi0_model = pi0_model.to(torch.float16)
    else:
        raise ValueError(f"Invalid precision: {precision}")
# Save the converted model using safetensors
os.makedirs(output_path, exist_ok=True)
# Save model weights as SafeTensors using save_model to handle tied weights
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
# Copy assets folder if it exists
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
if assets_source.exists():
assets_dest = pathlib.Path(output_path) / "assets"
if assets_dest.exists():
shutil.rmtree(assets_dest)
shutil.copytree(assets_source, assets_dest)
# Save config as JSON for reference
config_dict = {
"action_dim": model_config.action_dim,
"action_horizon": model_config.action_horizon,
"paligemma_variant": model_config.paligemma_variant,
"action_expert_variant": model_config.action_expert_variant,
"precision": precision,
}
with open(os.path.join(output_path, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
print("Model conversion completed successfully!")
print(f"Model saved to {output_path}")
def main(
checkpoint_dir: str,
config_name: str,
output_path: str | None = None,
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
*,
inspect_only: bool = False,
):
"""Load JAX model and optionally convert to PyTorch.
Args:
checkpoint_dir: Path to the JAX checkpoint directory
output_path: Path to save converted PyTorch model (required for conversion)
precision: Precision for model conversion
inspect_only: Only inspect parameter keys, don't convert
"""
model_config = _config.get_config(config_name).model
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
raise ValueError(f"Config {config_name} is not a Pi0Config")
if inspect_only:
load_jax_model_and_print_keys(checkpoint_dir)
else:
if not output_path:
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
return
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
if __name__ == "__main__":
tyro.cli(main)

View File

@@ -0,0 +1,84 @@
# DROID Policies in openpi
We offer instructions for:
- [Running inference for our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies)
- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)
## Running DROID Inference
This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.
### Step 1: Start a policy server
Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
2. Start the OpenPI server via the following command:
```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
```
You can also run the equivalent command below:
```bash
uv run scripts/serve_policy.py --env=DROID
```
### Step 2: Run the DROID robot
1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
2. On the control laptop, activate your DROID conda environment.
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
7. Run the `main.py` file, making sure to point it at the IP address and port of your policy server. (To check that the server machine is reachable from the DROID laptop, run `ping <server_ip>`.) Also specify which external camera to feed to the policy (we only input one external camera); choose from `["left", "right"]`.
```bash
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
```
The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
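For reference, the query loop that `main.py` implements boils down to packing each control step into a flat observation dict and sending it to the policy server. The sketch below is illustrative, not the authoritative client: the observation key names and the `openpi_client.websocket_client_policy` usage are assumptions to be checked against `scripts/main.py` and the openpi-client package.

```python
import numpy as np


def build_observation(exterior_rgb, wrist_rgb, joint_position, gripper_position, prompt):
    """Pack one control step into the flat dict sent to the policy server.

    The key names follow the DROID conventions used elsewhere in openpi, but
    treat them as illustrative -- check scripts/main.py for the real schema.
    """
    return {
        "observation/exterior_image_1_left": np.asarray(exterior_rgb, dtype=np.uint8),
        "observation/wrist_image_left": np.asarray(wrist_rgb, dtype=np.uint8),
        "observation/joint_position": np.asarray(joint_position, dtype=np.float32),
        "observation/gripper_position": np.asarray(gripper_position, dtype=np.float32),
        "prompt": prompt,
    }


def query_policy_server(host, port, obs):
    # Hypothetical usage of the openpi-client websocket policy: requires a
    # running policy server (Step 1) and the openpi-client package (Step 2.3).
    from openpi_client import websocket_client_policy

    client = websocket_client_policy.WebsocketClientPolicy(host=host, port=port)
    return client.infer(obs)["actions"]  # action chunk of shape (horizon, action_dim)
```

In the real loop you would execute (part of) the returned action chunk on the robot, capture fresh camera frames and proprioception, and query again.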
## Troubleshooting
| Issue | Solution |
|-------|----------|
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explorer` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy struggles, try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is conditioned on only a single external camera plus the wrist camera, so make sure you are feeding it the desired camera); use `ZED_Explorer` to verify the view. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
## Running Other Policies
We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
```
# Trained from pi0-FAST, using the FAST tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
# Trained from pi0, using flow matching.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid
# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
# Trained from PaliGemma, using FSQ tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
```
You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).

View File

@@ -0,0 +1,106 @@
# Training on DROID
Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline (with small differences in data loading and the action space used). For a tutorial on how to fine-tune a model on a smaller, custom dataset collected on the DROID platform, see below.
In contrast to the rest of openpi, which uses LeRobot for data loading, full DROID training uses the RLDS data format (at the moment, LeRobot does not scale well to larger datasets like DROID, though the LeRobot team is working on improving this). Below, we provide instructions for updating your openpi environment for RLDS data loading and for downloading the DROID dataset.
## Install
We need a few additional dependencies for RLDS data loading. Run:
```bash
uv sync --group rlds
```
## Download DROID dataset
You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
```
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
```
Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).
You will need 1.8TB of disk storage to download the DROID RLDS dataset.
## Run
First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
Then, compute normalization statistics (this will take ~10 minutes):
```bash
uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
```
Run training:
```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
```
**Note**: The original pi0.5-DROID model was trained with joint velocity actions. However, joint velocity actions are not compatible with simulated evaluation environments (velocities are much harder to simulate faithfully), so we do not recommend training with them and use joint position actions here instead.
## Compute Requirements
Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, batch size 256, approx. 1 epoch).
If you start from PaliGemma instead of pi0 initialization, plan for ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).
We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
## Data Filtering
Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection, we will not go into too much detail here). Appropriate filtering of these idle transitions can improve policy performance.
By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
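The idle filter described above can be sketched in plain NumPy. This is a simplified, single-episode version of the logic in `compute_droid_nonidle_ranges.py` (using that script's default thresholds), not the script itself:

```python
import numpy as np


def keep_ranges(joint_velocities, min_idle_len=7, min_non_idle_len=16, trim_last_n=10):
    """Return (start, end) index ranges of sustained motion for one episode:
    drop idle runs of >= min_idle_len steps, keep only non-idle runs of
    >= min_non_idle_len steps, and trim the tail of each kept range (whose
    action chunks would contain many idle actions)."""
    v = np.asarray(joint_velocities)
    # A step is "idle" if the joint velocities barely changed from the previous step.
    idle = np.hstack([[False], np.all(np.abs(v[1:] - v[:-1]) < 1e-3, axis=1)])

    def runs(mask):
        # Pad with False so runs touching the boundaries produce matching transitions.
        padded = np.concatenate([[False], mask, [False]]).astype(int)
        d = np.diff(padded)
        return np.where(d == 1)[0], np.where(d == -1)[0]

    # Drop only idle runs that are long enough to matter.
    starts, ends = runs(idle)
    keep = np.ones(len(v), dtype=bool)
    for s, e in zip(starts, ends):
        if e - s >= min_idle_len:
            keep[s:e] = False

    # Keep only sufficiently long non-idle runs, trimming each tail.
    starts, ends = runs(keep)
    return [(int(s), int(e) - trim_last_n) for s, e in zip(starts, ends) if e - s >= min_non_idle_len]
```

For example, an episode with 30 steps of motion, 20 idle steps, and 30 more steps of motion yields two kept ranges, each with its last 10 frames trimmed.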
## RoboArena
Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).
# Fine-Tuning on Custom DROID Datasets
Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. As with other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.
Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset is relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above).
## Step 1: Converting your custom DROID dataset to LeRobot
We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
```
We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
```
For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5 min for the 30 demonstrations):
```
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
```
## Step 2: Run fine-tuning with your custom dataset
Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
You can easily modify the config in `config.py` (search for `pi05_droid_finetune`) to work with other base models or with your own custom DROID dataset.
To launch training:
```
uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
```
Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.

View File

@@ -0,0 +1,103 @@
"""
Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
that should be sampled during training (all others are filtered out).
Filtering logic:
We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
(defaults to 7 -- since most DROID action-chunking policies execute the first 8 actions of each generated chunk, filtering
this way means the policy will not get stuck outputting stationary actions). Additionally, we only keep non-idle
ranges of length at least min_non_idle_len (defaults to 16 frames = ~1 second), while also removing the last
filter_last_n_in_ranges frames from the end of each range (as those correspond to action chunks with many idle actions).
This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
"""
import json
import os
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU
builder = tfds.builder_from_directory(
# path to the `droid` directory (not its parent)
builder_dir="<path_to_droid_dataset_tfds_files>",
)
ds = builder.as_dataset(split="train", shuffle_files=False)
ds = ds.apply(tf.data.experimental.ignore_errors())  # Skip corrupted episodes instead of raising during iteration
keep_ranges_path = "<path_to_where_to_save_the_json>"
min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out
min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out
filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range
keep_ranges_map = {}
if Path(keep_ranges_path).exists():
with Path(keep_ranges_path).open("r") as f:
keep_ranges_map = json.load(f)
print(f"Resuming from {len(keep_ranges_map)} episodes already processed")
for ep_idx, ep in enumerate(tqdm(ds)):
recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
file_path = ep["episode_metadata"]["file_path"].numpy().decode()
key = f"{recording_folderpath}--{file_path}"
if key in keep_ranges_map:
continue
joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
joint_velocities = np.array(joint_velocities)
is_idle_array = np.hstack(
[np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
)
# Find what steps go from idle to non-idle and vice-versa
    is_idle_padded = np.concatenate(
        [[False], is_idle_array, [False]]
    )  # Pad with False so idle runs at the start/end of the episode still produce matching start/end transitions
is_idle_diff = np.diff(is_idle_padded.astype(int))
is_idle_true_starts = np.where(is_idle_diff == 1)[0] # +1 transitions --> going from idle to non-idle
is_idle_true_ends = np.where(is_idle_diff == -1)[0] # -1 transitions --> going from non-idle to idle
# Find which steps correspond to idle segments of length at least min_idle_len
true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
is_idle_true_starts = is_idle_true_starts[true_segment_masks]
is_idle_true_ends = is_idle_true_ends[true_segment_masks]
keep_mask = np.ones(len(joint_velocities), dtype=bool)
for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
keep_mask[start:end] = False
# Get all non-idle ranges of at least 16
# Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
keep_padded = np.concatenate([[False], keep_mask, [False]])
keep_diff = np.diff(keep_padded.astype(int))
keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep
keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out
# Find which steps correspond to non-idle segments of length at least min_non_idle_len
true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
keep_true_starts = keep_true_starts[true_segment_masks]
keep_true_ends = keep_true_ends[true_segment_masks]
# Add mapping from episode unique ID key to list of non-idle ranges to keep
keep_ranges_map[key] = []
for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))
if ep_idx % 1000 == 0:
with Path(keep_ranges_path).open("w") as f:
json.dump(keep_ranges_map, f)
print("Done!")
with Path(keep_ranges_path).open("w") as f:
json.dump(keep_ranges_map, f)

View File

@@ -0,0 +1,477 @@
"""
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
Usage:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
The resulting dataset will get saved to the $LEROBOT_HOME directory.
"""
from collections import defaultdict
import copy
import glob
import json
from pathlib import Path
import shutil
import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
from PIL import Image
from tqdm import tqdm
import tyro
REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
def resize_image(image, size):
image = Image.fromarray(image)
return np.array(image.resize(size, resample=Image.BICUBIC))
def main(data_dir: str, *, push_to_hub: bool = False):
# Clean up any existing dataset in the output directory
output_path = HF_LEROBOT_HOME / REPO_NAME
if output_path.exists():
shutil.rmtree(output_path)
data_dir = Path(data_dir)
# Create LeRobot dataset, define features to store
# We will follow the DROID data naming conventions here.
# LeRobot assumes that dtype of image data is `image`
dataset = LeRobotDataset.create(
repo_id=REPO_NAME,
robot_type="panda",
fps=15, # DROID data is typically recorded at 15fps
features={
# We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
"exterior_image_1_left": {
"dtype": "image",
"shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
"names": ["height", "width", "channel"],
},
"exterior_image_2_left": {
"dtype": "image",
"shape": (180, 320, 3),
"names": ["height", "width", "channel"],
},
"wrist_image_left": {
"dtype": "image",
"shape": (180, 320, 3),
"names": ["height", "width", "channel"],
},
"joint_position": {
"dtype": "float32",
"shape": (7,),
"names": ["joint_position"],
},
"gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": ["gripper_position"],
},
"actions": {
"dtype": "float32",
"shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D)
"names": ["actions"],
},
},
image_writer_threads=10,
image_writer_processes=5,
)
# Load language annotations
# Note: we load the DROID language annotations for this example, but you can manually define them for your own data
with (data_dir / "aggregated-annotations-030724.json").open() as f:
language_annotations = json.load(f)
# Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
# We assume the following directory structure:
# RAW_DROID_PATH/
# - <...>/
# - recordings/
# - MP4/
# - <camera_id>.mp4 # single-view video of left stereo pair camera
# - trajectory.h5
# - <...>/
episode_paths = list(data_dir.glob("**/trajectory.h5"))
print(f"Found {len(episode_paths)} episodes for conversion")
# We will loop over each dataset_name and write episodes to the LeRobot dataset
for episode_path in tqdm(episode_paths, desc="Converting episodes"):
# Load raw data
recording_folderpath = episode_path.parent / "recordings" / "MP4"
trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))
# To load the language instruction, we need to parse out the episode_id from the metadata file
# Again, you can modify this step for your own data, to load your own language instructions
metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
"language_instruction1"
]
print(f"Converting episode with language instruction: {language_instruction}")
# Write to LeRobot dataset
for step in trajectory:
camera_type_dict = step["observation"]["camera_type"]
wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
dataset.add_frame(
{
# Note: need to flip BGR --> RGB for loaded images
"exterior_image_1_left": resize_image(
step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
),
"exterior_image_2_left": resize_image(
step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
),
"wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
"joint_position": np.asarray(
step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
),
"gripper_position": np.asarray(
step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
),
# Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
"actions": np.concatenate(
[step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
),
"task": language_instruction,
}
)
dataset.save_episode()
# Optionally push to the Hugging Face Hub
if push_to_hub:
dataset.push_to_hub(
tags=["droid", "panda"],
private=False,
push_videos=True,
license="apache-2.0",
)
##########################################################################################################
################ The rest of this file are functions to parse the raw DROID data #########################
################ You don't need to worry about understanding this part #########################
################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
##########################################################################################################
camera_type_dict = {
"hand_camera_id": 0,
"varied_camera_1_id": 1,
"varied_camera_2_id": 1,
}
camera_type_to_string_dict = {
0: "hand_camera",
1: "varied_camera",
2: "fixed_camera",
}
def get_camera_type(cam_id):
if cam_id not in camera_type_dict:
return None
type_int = camera_type_dict[cam_id]
return camera_type_to_string_dict[type_int]
class MP4Reader:
def __init__(self, filepath, serial_number):
# Save Parameters #
self.serial_number = serial_number
self._index = 0
# Open Video Reader #
self._mp4_reader = cv2.VideoCapture(filepath)
if not self._mp4_reader.isOpened():
raise RuntimeError("Corrupted MP4 File")
def set_reading_parameters(
self,
image=True, # noqa: FBT002
concatenate_images=False, # noqa: FBT002
resolution=(0, 0),
resize_func=None,
):
# Save Parameters #
self.image = image
self.concatenate_images = concatenate_images
self.resolution = resolution
self.resize_func = cv2.resize
self.skip_reading = not image
if self.skip_reading:
return
def get_frame_resolution(self):
width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
return (width, height)
def get_frame_count(self):
if self.skip_reading:
return 0
return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))
def set_frame_index(self, index):
if self.skip_reading:
return
if index < self._index:
self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
self._index = index
while self._index < index:
self.read_camera(ignore_data=True)
def _process_frame(self, frame):
frame = copy.deepcopy(frame)
if self.resolution == (0, 0):
return frame
return self.resize_func(frame, self.resolution)
def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002
# Skip if Read Unnecessary #
if self.skip_reading:
return {}
# Read Camera #
success, frame = self._mp4_reader.read()
self._index += 1
if not success:
return None
if ignore_data:
return None
# Return Data #
data_dict = {}
if self.concatenate_images or "stereo" not in self.serial_number:
data_dict["image"] = {self.serial_number: self._process_frame(frame)}
else:
single_width = frame.shape[1] // 2
data_dict["image"] = {
self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
}
return data_dict
def disable_camera(self):
if hasattr(self, "_mp4_reader"):
self._mp4_reader.release()
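The side-by-side stereo split in `read_camera` is just two column slices over the frame width. A minimal standalone sketch (the helper name is hypothetical):

```python
import numpy as np

def split_stereo(frame):
    """Split a side-by-side stereo frame (H, 2W, C) into left/right halves,
    mirroring the "_left"/"_right" keys produced by MP4Reader.read_camera."""
    half = frame.shape[1] // 2
    return frame[:, :half, :], frame[:, half:, :]

frame = np.arange(2 * 8 * 3).reshape(2, 8, 3)
left, right = split_stereo(frame)
print(left.shape, right.shape)  # (2, 4, 3) (2, 4, 3)
```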
class RecordedMultiCameraWrapper:
def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006
# Save Camera Info #
self.camera_kwargs = camera_kwargs
# Open Camera Readers #
mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
all_filepaths = mp4_filepaths
self.camera_dict = {}
for f in all_filepaths:
serial_number = f.split("/")[-1][:-4]
cam_type = get_camera_type(serial_number)
camera_kwargs.get(cam_type, {})
if f.endswith(".mp4"):
Reader = MP4Reader # noqa: N806
else:
raise ValueError
self.camera_dict[serial_number] = Reader(f, serial_number)
def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006
full_obs_dict = defaultdict(dict)
# Read Cameras In Randomized Order #
all_cam_ids = list(self.camera_dict.keys())
# random.shuffle(all_cam_ids)
for cam_id in all_cam_ids:
if "stereo" in cam_id:
continue
try:
cam_type = camera_type_dict[cam_id]
except KeyError:
print(f"{self.camera_dict} -- {camera_type_dict}")
raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") # noqa: B904
curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)
timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
if index is not None:
self.camera_dict[cam_id].set_frame_index(index)
data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)
# Process Returned Data #
if data_dict is None:
return None
for key in data_dict:
full_obs_dict[key].update(data_dict[key])
return full_obs_dict
def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006
length = None
for key in hdf5_file:
if key in keys_to_ignore:
continue
curr_data = hdf5_file[key]
if isinstance(curr_data, h5py.Group):
curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
elif isinstance(curr_data, h5py.Dataset):
curr_length = len(curr_data)
else:
raise ValueError
if length is None:
length = curr_length
assert curr_length == length
return length
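`get_hdf5_length` recurses through groups and asserts that every dataset agrees on length. The same invariant on a plain nested dict of sequences, as a toy stand-in for the HDF5 tree:

```python
def nested_length(tree):
    """Return the common leaf length of a nested dict of sequences,
    asserting all leaves agree (mirrors get_hdf5_length above)."""
    length = None
    for value in tree.values():
        curr = nested_length(value) if isinstance(value, dict) else len(value)
        if length is None:
            length = curr
        assert curr == length, "inconsistent leaf lengths"
    return length

print(nested_length({"a": [1, 2, 3], "b": {"c": [4, 5, 6]}}))  # 3
```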
def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006
data_dict = {}
for key in hdf5_file:
if key in keys_to_ignore:
continue
curr_data = hdf5_file[key]
if isinstance(curr_data, h5py.Group):
data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
elif isinstance(curr_data, h5py.Dataset):
data_dict[key] = curr_data[index]
else:
raise ValueError
return data_dict
class TrajectoryReader:
def __init__(self, filepath, read_images=True): # noqa: FBT002
self._hdf5_file = h5py.File(filepath, "r")
is_video_folder = "observations/videos" in self._hdf5_file
self._read_images = read_images and is_video_folder
self._length = get_hdf5_length(self._hdf5_file)
self._video_readers = {}
self._index = 0
def length(self):
return self._length
def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006
# Make Sure We Read Within Range #
if index is None:
index = self._index
else:
assert not self._read_images
self._index = index
assert index < self._length
# Load Low Dimensional Data #
keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)
# Increment Read Index #
self._index += 1
# Return Timestep #
return timestep
def close(self):
self._hdf5_file.close()
def load_trajectory(
filepath=None,
read_cameras=True, # noqa: FBT002
recording_folderpath=None,
camera_kwargs={}, # noqa: B006
remove_skipped_steps=False, # noqa: FBT002
num_samples_per_traj=None,
num_samples_per_traj_coeff=1.5,
):
read_recording_folderpath = read_cameras and (recording_folderpath is not None)
traj_reader = TrajectoryReader(filepath)
if read_recording_folderpath:
camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)
horizon = traj_reader.length()
timestep_list = []
# Choose Timesteps To Save #
if num_samples_per_traj:
num_to_save = num_samples_per_traj
if remove_skipped_steps:
num_to_save = int(num_to_save * num_samples_per_traj_coeff)
max_size = min(num_to_save, horizon)
indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
else:
indices_to_save = np.arange(horizon)
# Iterate Over Trajectory #
for i in indices_to_save:
# Get HDF5 Data #
timestep = traj_reader.read_timestep(index=i)
# If Applicable, Get Recorded Data #
if read_recording_folderpath:
timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
camera_type_dict = {
k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
}
camera_obs = camera_reader.read_cameras(
index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
)
camera_failed = camera_obs is None
# Add Data To Timestep If Successful #
if camera_failed:
break
timestep["observation"].update(camera_obs)
# Filter Steps #
step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
delete_skipped_step = step_skipped and remove_skipped_steps
# Save Filtered Timesteps #
if delete_skipped_step:
del timestep
else:
timestep_list.append(timestep)
# Remove Extra Transitions #
timestep_list = np.array(timestep_list)
if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
timestep_list = timestep_list[ind_to_keep]
# Close Readers #
traj_reader.close()
# Return Data #
return timestep_list
if __name__ == "__main__":
tyro.cli(main)
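When `num_samples_per_traj` is set, `load_trajectory` oversamples by `num_samples_per_traj_coeff` so that enough steps survive skipped-step filtering, then keeps a sorted random subset. The index-selection step in isolation (a seeded `default_rng` is used here for reproducibility; the script itself uses the legacy `np.random` API):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
horizon, num_samples, coeff = 50, 10, 1.5

# Oversample by coeff, but never ask for more indices than exist.
num_to_save = min(int(num_samples * coeff), horizon)
# Unique indices, sorted so timesteps stay in temporal order.
indices = np.sort(rng.choice(horizon, size=num_to_save, replace=False))
print(num_to_save, indices.size)  # 15 15
```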


@@ -0,0 +1,246 @@
# ruff: noqa
import contextlib
import dataclasses
import datetime
import faulthandler
import os
import signal
import time
from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from droid.robot_env import RobotEnv
import tqdm
import tyro
faulthandler.enable()
# DROID data collection frequency -- we slow down execution to match this frequency
DROID_CONTROL_FREQUENCY = 15
@dataclasses.dataclass
class Args:
# Hardware parameters
left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
# Policy parameters
external_camera: str | None = (
None # which external camera should be fed to the policy, choose from ["left", "right"]
)
# Rollout parameters
max_timesteps: int = 600
# How many actions to execute from a predicted action chunk before querying policy server again
# 8 is usually a good default (equals 0.5 seconds of action execution).
open_loop_horizon: int = 8
# Remote server parameters
remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
remote_port: int = (
8000 # point this to the port of the policy server, default server port for openpi servers is 8000
)
# We use Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
# waiting for a new action chunk, it raises an exception and the server connection dies.
# This context manager temporarily blocks Ctrl+C and delays it until after the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
interrupted = False
original_handler = signal.getsignal(signal.SIGINT)
def handler(signum, frame):
nonlocal interrupted
interrupted = True
signal.signal(signal.SIGINT, handler)
try:
yield
finally:
signal.signal(signal.SIGINT, original_handler)
if interrupted:
raise KeyboardInterrupt
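A self-contained way to sanity-check this deferral pattern off-robot: send the process a real SIGINT inside the protected block and confirm it is only raised afterwards. Same logic as above; `deferred_sigint` is a hypothetical name, and the demo is POSIX-only since it delivers an actual signal:

```python
import contextlib
import os
import signal

@contextlib.contextmanager
def deferred_sigint():
    """Record SIGINT while the block runs, restore the old handler,
    then re-raise KeyboardInterrupt on exit (same pattern as above)."""
    interrupted = False
    original = signal.getsignal(signal.SIGINT)

    def handler(signum, frame):
        nonlocal interrupted
        interrupted = True

    signal.signal(signal.SIGINT, handler)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, original)
        if interrupted:
            raise KeyboardInterrupt

survived = reached = False
try:
    with deferred_sigint():
        os.kill(os.getpid(), signal.SIGINT)  # simulate Ctrl+C mid-block
        survived = True                      # still runs: SIGINT was deferred
except KeyboardInterrupt:
    reached = True                           # re-raised only after the block
print(survived, reached)
```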
def main(args: Args):
# Make sure external camera is specified by user -- we only use one external camera for the policy
assert (
args.external_camera is not None and args.external_camera in ["left", "right"]
), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
# Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
print("Created the droid env!")
# Connect to the policy server
policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
while True:
instruction = input("Enter instruction: ")
# Rollout parameters
actions_from_chunk_completed = 0
pred_action_chunk = None
# Prepare to save video of rollout
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
video = []
bar = tqdm.tqdm(range(args.max_timesteps))
print("Running rollout... press Ctrl+C to stop early.")
for t_step in bar:
start_time = time.time()
try:
# Get the current observation
curr_obs = _extract_observation(
args,
env.get_observation(),
# Save the first observation to disk
save_to_disk=t_step == 0,
)
video.append(curr_obs[f"{args.external_camera}_image"])
# Send websocket request to policy server if it's time to predict a new chunk
if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
actions_from_chunk_completed = 0
# We resize images on the robot laptop to minimize the amount of data sent to the policy server
# and improve latency.
request_data = {
"observation/exterior_image_1_left": image_tools.resize_with_pad(
curr_obs[f"{args.external_camera}_image"], 224, 224
),
"observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
"observation/joint_position": curr_obs["joint_position"],
"observation/gripper_position": curr_obs["gripper_position"],
"prompt": instruction,
}
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
# Ctrl+C will be handled after the server call is complete
with prevent_keyboard_interrupt():
# this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
pred_action_chunk = policy_client.infer(request_data)["actions"]
assert pred_action_chunk.shape == (10, 8)
# Select current action to execute from chunk
action = pred_action_chunk[actions_from_chunk_completed]
actions_from_chunk_completed += 1
# Binarize gripper action
if action[-1].item() > 0.5:
# action[-1] = 1.0
action = np.concatenate([action[:-1], np.ones((1,))])
else:
# action[-1] = 0.0
action = np.concatenate([action[:-1], np.zeros((1,))])
# clip all dimensions of action to [-1, 1]
action = np.clip(action, -1, 1)
env.step(action)
# Sleep to match DROID data collection frequency
elapsed_time = time.time() - start_time
if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
except KeyboardInterrupt:
break
video = np.stack(video)
save_filename = "video_" + timestamp
ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
success: str | float | None = None
while not isinstance(success, float):
success = input(
"Did the rollout succeed? (enter y for 100%, n for 0%, or a numeric value 0-100 based on the evaluation spec) "
)
if success == "y":
success = 1.0
elif success == "n":
success = 0.0
else:
try:
success = float(success) / 100
except ValueError:
print("Please enter y, n, or a number between 0 and 100.")
success = None
continue
if not (0 <= success <= 1):
print(f"Success must be a number in [0, 100] but got: {success * 100}")
success = None
# pandas >= 2.0 removed DataFrame.append; use pd.concat instead
df = pd.concat(
[df, pd.DataFrame([{"success": success, "duration": t_step, "video_filename": save_filename}])],
ignore_index=True,
)
if input("Do one more eval? (enter y or n) ").lower() != "y":
break
env.reset()
os.makedirs("results", exist_ok=True)
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
df.to_csv(csv_filename)
print(f"Results saved to {csv_filename}")
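Stripped of hardware and networking, the open-loop chunking in the rollout loop reduces to the skeleton below; `run_chunked` and `fake_policy` are illustrative stand-ins, not part of the script:

```python
def run_chunked(query_policy, open_loop_horizon, max_timesteps):
    """Execute actions from a predicted chunk, re-querying the policy
    every open_loop_horizon steps (the loop structure used in main)."""
    executed, chunk, used = [], None, 0
    for _ in range(max_timesteps):
        if chunk is None or used >= open_loop_horizon:
            chunk = query_policy()  # stand-in for policy_client.infer(...)["actions"]
            used = 0
        executed.append(chunk[used])
        used += 1
    return executed

calls = []
def fake_policy():
    calls.append(None)
    return [(len(calls), i) for i in range(8)]  # a fresh 8-action chunk

actions = run_chunked(fake_policy, open_loop_horizon=8, max_timesteps=20)
print(len(calls))  # 20 steps with horizon 8 -> 3 policy queries
```

With `open_loop_horizon=8` at 15 Hz control, each chunk covers roughly half a second before the server is queried again, which matches the comment on the default value.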
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
image_observations = obs_dict["image"]
left_image, right_image, wrist_image = None, None, None
for key in image_observations:
# Note the "left" below refers to the left camera in the stereo pair.
# The model is only trained on left stereo cams, so we only feed those.
if args.left_camera_id in key and "left" in key:
left_image = image_observations[key]
elif args.right_camera_id in key and "left" in key:
right_image = image_observations[key]
elif args.wrist_camera_id in key and "left" in key:
wrist_image = image_observations[key]
# Drop the alpha dimension
left_image = left_image[..., :3]
right_image = right_image[..., :3]
wrist_image = wrist_image[..., :3]
# Convert to RGB
left_image = left_image[..., ::-1]
right_image = right_image[..., ::-1]
wrist_image = wrist_image[..., ::-1]
# In addition to image observations, also capture the proprioceptive state
robot_state = obs_dict["robot_state"]
cartesian_position = np.array(robot_state["cartesian_position"])
joint_position = np.array(robot_state["joint_positions"])
gripper_position = np.array([robot_state["gripper_position"]])
# Save the images to disk so that they can be viewed live while the robot is running
# Create one combined image to make live viewing easy
if save_to_disk:
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
combined_image = Image.fromarray(combined_image)
combined_image.save("robot_camera_views.png")
return {
"left_image": left_image,
"right_image": right_image,
"wrist_image": wrist_image,
"cartesian_position": cartesian_position,
"joint_position": joint_position,
"gripper_position": gripper_position,
}
if __name__ == "__main__":
args: Args = tyro.cli(Args)
main(args)
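The channel handling in `_extract_observation` (drop the alpha channel, then flip BGR to RGB) amounts to two slices; a sketch with a hypothetical helper name:

```python
import numpy as np

def bgra_to_rgb(frame):
    """Drop the alpha channel and reverse channel order (BGR -> RGB),
    as done per-image in _extract_observation."""
    return frame[..., :3][..., ::-1]

frame = np.zeros((2, 2, 4), dtype=np.uint8)
frame[..., 0] = 255  # blue channel in BGRA layout
rgb = bgra_to_rgb(frame)
print(rgb.shape, rgb[0, 0].tolist())  # (2, 2, 3) [0, 0, 255]
```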


@@ -0,0 +1,137 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import dataclasses\n",
"\n",
"import jax\n",
"\n",
"from openpi.models import model as _model\n",
"from openpi.policies import droid_policy\n",
"from openpi.policies import policy_config as _policy_config\n",
"from openpi.shared import download\n",
"from openpi.training import config as _config\n",
"from openpi.training import data_loader as _data_loader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Policy inference\n",
"\n",
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_fast_droid\")\n",
"checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
"\n",
"# Create a trained policy.\n",
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
"\n",
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
"example = droid_policy.make_droid_example()\n",
"result = policy.infer(example)\n",
"\n",
"# Delete the policy to free up memory.\n",
"del policy\n",
"\n",
"print(\"Actions shape:\", result[\"actions\"].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Working with a live model\n",
"\n",
"\n",
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_aloha_sim\")\n",
"\n",
"checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
"key = jax.random.key(0)\n",
"\n",
"# Create a model from the checkpoint.\n",
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
"\n",
"# We can create fake observations and actions to test the model.\n",
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
"\n",
"# Sample actions from the model.\n",
"loss = model.compute_loss(key, obs, act)\n",
"print(\"Loss shape:\", loss.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reduce the batch size to reduce memory usage.\n",
"config = dataclasses.replace(config, batch_size=2)\n",
"\n",
"# Load a single batch of data. This is the same data that will be used during training.\n",
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
"obs, act = next(iter(loader))\n",
"\n",
"# Sample actions from the model.\n",
"loss = model.compute_loss(key, obs, act)\n",
"\n",
"# Delete the model to free up memory.\n",
"del model\n",
"\n",
"print(\"Loss shape:\", loss.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -0,0 +1,59 @@
# Dockerfile for the LIBERO benchmark.
# Build the container:
# docker build . -t libero -f examples/libero/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
make \
g++ \
clang \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
# Create a default config file to avoid an input prompt from LIBERO's init script.
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
ENV LIBERO_CONFIG_PATH=/tmp/libero
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
benchmark_root: /app/third_party/libero/libero/libero
bddl_files: /app/third_party/libero/libero/libero/bddl_files
init_states: /app/third_party/libero/libero/libero/init_files
datasets: /app/third_party/libero/libero/datasets
assets: /app/third_party/libero/libero/libero/assets
EOF
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]


@@ -0,0 +1,71 @@
# LIBERO Benchmark
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
This example requires git submodules to be initialized. Don't forget to run:
```bash
git submodule update --init --recursive
```
## With Docker (recommended)
```bash
# Grant access to the X11 server:
sudo xhost +local:docker
# To run with the default checkpoint and task suite:
SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
# To run with glx for Mujoco instead (use this if you have egl errors):
MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
```
You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
For example:
```bash
# To load a custom checkpoint (located in the top-level openpi/ directory):
export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
# To run the libero_10 task suite:
export CLIENT_ARGS="--args.task-suite-name libero_10"
```
## Without Docker (not recommended)
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.8 examples/libero/.venv
source examples/libero/.venv/bin/activate
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
uv pip install -e packages/openpi-client
uv pip install -e third_party/libero
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
# Run the simulation
python examples/libero/main.py
# To run with glx for Mujoco instead (use this if you have egl errors):
MUJOCO_GL=glx python examples/libero/main.py
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env LIBERO
```
## Results
If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
checkpoint was trained in openpi with the `pi05_libero` config.
| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|-------|---------------|---------------|-------------|-----------|---------|
| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 |


@@ -0,0 +1,54 @@
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
runtime:
image: libero
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/libero/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
- /tmp/.X11-unix:/tmp/.X11-unix:ro
environment:
- CLIENT_ARGS
- DISPLAY=$DISPLAY
- MUJOCO_GL=${MUJOCO_GL:-egl}
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]


@@ -0,0 +1,104 @@
"""
Minimal example script for converting a dataset to LeRobot format.
We use the Libero dataset (stored in RLDS) for this example, but it can be easily
modified for any other data you have saved in a custom format.
Usage:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
Note: to run the script, you need to install tensorflow_datasets:
`uv pip install tensorflow tensorflow_datasets`
You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
Running this conversion script will take approximately 30 minutes.
"""
import shutil
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds
import tyro
REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
RAW_DATASET_NAMES = [
"libero_10_no_noops",
"libero_goal_no_noops",
"libero_object_no_noops",
"libero_spatial_no_noops",
] # For simplicity we will combine multiple Libero datasets into one training dataset
def main(data_dir: str, *, push_to_hub: bool = False):
# Clean up any existing dataset in the output directory
output_path = HF_LEROBOT_HOME / REPO_NAME
if output_path.exists():
shutil.rmtree(output_path)
# Create LeRobot dataset, define features to store
# OpenPi assumes that proprio is stored in `state` and actions in `actions`
# LeRobot assumes that dtype of image data is `image`
dataset = LeRobotDataset.create(
repo_id=REPO_NAME,
robot_type="panda",
fps=10,
features={
"image": {
"dtype": "image",
"shape": (256, 256, 3),
"names": ["height", "width", "channel"],
},
"wrist_image": {
"dtype": "image",
"shape": (256, 256, 3),
"names": ["height", "width", "channel"],
},
"state": {
"dtype": "float32",
"shape": (8,),
"names": ["state"],
},
"actions": {
"dtype": "float32",
"shape": (7,),
"names": ["actions"],
},
},
image_writer_threads=10,
image_writer_processes=5,
)
# Loop over raw Libero datasets and write episodes to the LeRobot dataset
# You can modify this for your own data format
for raw_dataset_name in RAW_DATASET_NAMES:
raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
for episode in raw_dataset:
for step in episode["steps"].as_numpy_iterator():
dataset.add_frame(
{
"image": step["observation"]["image"],
"wrist_image": step["observation"]["wrist_image"],
"state": step["observation"]["state"],
"actions": step["action"],
"task": step["language_instruction"].decode(),
}
)
dataset.save_episode()
# Optionally push to the Hugging Face Hub
if push_to_hub:
dataset.push_to_hub(
tags=["libero", "panda", "rlds"],
private=False,
push_videos=True,
license="apache-2.0",
)
if __name__ == "__main__":
tyro.cli(main)

View File

@@ -0,0 +1,219 @@
import collections
import dataclasses
import logging
import math
import pathlib
import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
@dataclasses.dataclass
class Args:
    #################################################################################################################
    # Model server parameters
    #################################################################################################################
    host: str = "0.0.0.0"
    port: int = 8000
    resize_size: int = 224
    replan_steps: int = 5

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = (
        "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    )
    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50  # Number of rollouts per task

    #################################################################################################################
    # Utils
    #################################################################################################################
    video_out_path: str = "data/libero/videos"  # Path to save videos
    seed: int = 7  # Random Seed (for reproducibility)
def eval_libero(args: Args) -> None:
    # Set random seed
    np.random.seed(args.seed)

    # Initialize LIBERO task suite
    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    num_tasks_in_suite = task_suite.n_tasks
    logging.info(f"Task suite: {args.task_suite_name}")

    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)

    if args.task_suite_name == "libero_spatial":
        max_steps = 220  # longest training demo has 193 steps
    elif args.task_suite_name == "libero_object":
        max_steps = 280  # longest training demo has 254 steps
    elif args.task_suite_name == "libero_goal":
        max_steps = 300  # longest training demo has 270 steps
    elif args.task_suite_name == "libero_10":
        max_steps = 520  # longest training demo has 505 steps
    elif args.task_suite_name == "libero_90":
        max_steps = 400  # longest training demo has 373 steps
    else:
        raise ValueError(f"Unknown task suite: {args.task_suite_name}")

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

    # Start evaluation
    total_episodes, total_successes = 0, 0
    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
        # Get task
        task = task_suite.get_task(task_id)

        # Get default LIBERO initial states
        initial_states = task_suite.get_task_init_states(task_id)

        # Initialize LIBERO environment and task description
        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

        # Start episodes
        task_episodes, task_successes = 0, 0
        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
            logging.info(f"\nTask: {task_description}")

            # Reset environment
            env.reset()
            action_plan = collections.deque()

            # Set initial states
            obs = env.set_init_state(initial_states[episode_idx])

            # Setup
            t = 0
            replay_images = []

            logging.info(f"Starting episode {task_episodes+1}...")
            while t < max_steps + args.num_steps_wait:
                try:
                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
                    # and we need to wait for them to fall
                    if t < args.num_steps_wait:
                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
                        t += 1
                        continue

                    # Get preprocessed image
                    # IMPORTANT: rotate 180 degrees to match train preprocessing
                    img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
                    wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
                    img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
                    )
                    wrist_img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
                    )

                    # Save preprocessed image for replay video
                    replay_images.append(img)

                    if not action_plan:
                        # Finished executing previous action chunk -- compute new chunk
                        # Prepare observations dict
                        element = {
                            "observation/image": img,
                            "observation/wrist_image": wrist_img,
                            "observation/state": np.concatenate(
                                (
                                    obs["robot0_eef_pos"],
                                    _quat2axisangle(obs["robot0_eef_quat"]),
                                    obs["robot0_gripper_qpos"],
                                )
                            ),
                            "prompt": str(task_description),
                        }

                        # Query model to get action
                        action_chunk = client.infer(element)["actions"]
                        assert (
                            len(action_chunk) >= args.replan_steps
                        ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                        action_plan.extend(action_chunk[: args.replan_steps])

                    action = action_plan.popleft()

                    # Execute action in environment
                    obs, reward, done, info = env.step(action.tolist())
                    if done:
                        task_successes += 1
                        total_successes += 1
                        break
                    t += 1

                except Exception as e:
                    logging.error(f"Caught exception: {e}")
                    break

            task_episodes += 1
            total_episodes += 1

            # Save a replay video of the episode
            suffix = "success" if done else "failure"
            task_segment = task_description.replace(" ", "_")
            imageio.mimwrite(
                pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
                [np.asarray(x) for x in replay_images],
                fps=10,
            )

            # Log current results
            logging.info(f"Success: {done}")
            logging.info(f"# episodes completed so far: {total_episodes}")
            logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")

        # Log final results
        logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
        logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")

    logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
    logging.info(f"Total episodes: {total_episodes}")
def _get_libero_env(task, resolution, seed):
    """Initializes and returns the LIBERO environment, along with the task description."""
    task_description = task.language
    task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
    env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
    env = OffScreenRenderEnv(**env_args)
    env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    return env, task_description


def _quat2axisangle(quat):
    """
    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
    """
    # clip quaternion
    if quat[3] > 1.0:
        quat[3] = 1.0
    elif quat[3] < -1.0:
        quat[3] = -1.0

    den = np.sqrt(1.0 - quat[3] * quat[3])
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(quat[3])) / den


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tyro.cli(eval_libero)

View File

@@ -0,0 +1,11 @@
imageio[ffmpeg]
numpy==1.22.4
tqdm
tyro
PyYaml
opencv-python==4.6.0.66
torch==1.11.0+cu113
torchvision==0.12.0+cu113
torchaudio==0.11.0+cu113
robosuite==1.4.1
matplotlib==3.5.3

View File

@@ -0,0 +1,136 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
absl-py==2.1.0
# via mujoco
certifi==2024.12.14
# via requests
charset-normalizer==3.4.0
# via requests
cycler==0.12.1
# via matplotlib
docstring-parser==0.16
# via tyro
etils==1.3.0
# via mujoco
eval-type-backport==0.2.0
# via tyro
evdev==1.7.1
# via pynput
fonttools==4.55.3
# via matplotlib
glfw==1.12.0
# via mujoco
idna==3.10
# via requests
imageio==2.35.1
# via -r examples/libero/requirements.in
imageio-ffmpeg==0.5.1
# via imageio
importlib-metadata==8.5.0
# via typeguard
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
llvmlite==0.36.0
# via numba
markdown-it-py==3.0.0
# via rich
matplotlib==3.5.3
# via -r examples/libero/requirements.in
mdurl==0.1.2
# via markdown-it-py
mujoco==3.2.3
# via robosuite
numba==0.53.1
# via robosuite
numpy==1.22.4
# via
# -r examples/libero/requirements.in
# imageio
# matplotlib
# mujoco
# numba
# opencv-python
# robosuite
# scipy
# torchvision
opencv-python==4.6.0.66
# via
# -r examples/libero/requirements.in
# robosuite
packaging==24.2
# via matplotlib
pillow==10.4.0
# via
# imageio
# matplotlib
# robosuite
# torchvision
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pynput==1.7.7
# via robosuite
pyopengl==3.1.7
# via mujoco
pyparsing==3.1.4
# via matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
python-xlib==0.33
# via pynput
pyyaml==6.0.2
# via -r examples/libero/requirements.in
requests==2.32.3
# via torchvision
rich==13.9.4
# via tyro
robosuite==1.4.1
# via -r examples/libero/requirements.in
scipy==1.10.1
# via robosuite
setuptools==75.3.0
# via
# imageio-ffmpeg
# numba
shtab==1.7.1
# via tyro
six==1.17.0
# via
# pynput
# python-dateutil
# python-xlib
termcolor==2.4.0
# via robosuite
torch==1.11.0+cu113
# via
# -r examples/libero/requirements.in
# torchaudio
# torchvision
torchaudio==0.11.0+cu113
# via -r examples/libero/requirements.in
torchvision==0.12.0+cu113
# via -r examples/libero/requirements.in
tqdm==4.67.1
# via -r examples/libero/requirements.in
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# torch
# torchvision
# typeguard
# tyro
tyro==0.9.2
# via -r examples/libero/requirements.in
urllib3==2.2.3
# via requests
zipp==3.20.2
# via
# etils
# importlib-metadata
# importlib-resources

View File

@@ -0,0 +1,134 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pathlib\n",
"\n",
"import numpy as np\n",
"\n",
"record_path = pathlib.Path(\"../policy_records\")\n",
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
"\n",
"records = []\n",
"for i in range(num_steps):\n",
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
" records.append(record)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"length of records\", len(records))\n",
"print(\"keys in records\", records[0].keys())\n",
"\n",
"for k in records[0]:\n",
" print(f\"{k} shape: {records[0][k].shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"def get_image(step: int, idx: int = 0):\n",
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
" return img[idx].transpose(1, 2, 0)\n",
"\n",
"\n",
"def show_image(step: int, idx_lst: list[int]):\n",
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
" return Image.fromarray(np.hstack(imgs))\n",
"\n",
"\n",
"for i in range(2):\n",
" display(show_image(i, [0]))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def get_axis(name, axis):\n",
" return np.array([record[name][axis] for record in records])\n",
"\n",
"\n",
"# qpos is [..., 14] of type float:\n",
"# 0-5: left arm joint angles\n",
"# 6: left arm gripper\n",
"# 7-12: right arm joint angles\n",
"# 13: right arm gripper\n",
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
"\n",
"\n",
"def make_data():\n",
" cur_dim = 0\n",
" in_data = {}\n",
" out_data = {}\n",
" for name, dim_size in names:\n",
" for i in range(dim_size):\n",
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
" cur_dim += 1\n",
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
"\n",
"\n",
"in_data, out_data = make_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in in_data.columns:\n",
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
" data.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,32 @@
# Dockerfile for the simple client.
# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"

View File

@@ -0,0 +1,30 @@
# Simple Client
A minimal client that sends observations to the server and prints the inference rate.
You can select which runtime environment to use with the `--env` flag. You can see the available options by running:
```bash
uv run examples/simple_client/main.py --help
```
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/simple_client/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
uv run examples/simple_client/main.py --env DROID
```
Terminal window 2:
```bash
uv run scripts/serve_policy.py --env DROID
```
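
The client's inference-rate measurement boils down to timing each `infer` call. Below is a minimal self-contained sketch of that loop; the websocket policy server is replaced by a hypothetical `fake_infer` stub (the real client would call `WebsocketClientPolicy.infer` instead), so this only illustrates the timing bookkeeping, not actual server latency.

```python
import time

import numpy as np


def measure_infer_ms(infer_fn, make_obs, num_steps: int = 20) -> list[float]:
    """Time each call to infer_fn in milliseconds, like the client's client_infer_ms metric."""
    timings_ms = []
    for _ in range(num_steps):
        start = time.time()
        infer_fn(make_obs())
        timings_ms.append(1000 * (time.time() - start))
    return timings_ms


def fake_infer(obs: dict) -> dict:
    # Stand-in for the policy server: returns a dummy action chunk.
    return {"actions": np.zeros((10, 14))}


def fake_obs() -> dict:
    return {"state": np.ones((14,)), "prompt": "do something"}


timings = measure_infer_ms(fake_infer, fake_obs)
print(f"mean latency: {np.mean(timings):.3f} ms over {len(timings)} calls")
```

Swapping `fake_infer` for a real policy client gives the same per-call statistics that the client prints in its summary table.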

View File

@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
runtime:
image: simple_client
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/simple_client/Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,187 @@
import dataclasses
import enum
import logging
import pathlib
import time

import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich.console  # import submodules explicitly; `import rich` alone does not expose rich.console/rich.table
import rich.table
import tqdm
import tyro

logger = logging.getLogger(__name__)
class EnvMode(enum.Enum):
    """Supported environments."""

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"


@dataclasses.dataclass
class Args:
    """Command line arguments."""

    # Host and port to connect to the server.
    host: str = "0.0.0.0"
    # Port to connect to the server. If None, the server will use the default port.
    port: int | None = 8000
    # API key to use for the server.
    api_key: str | None = None

    # Number of steps to run the policy for.
    num_steps: int = 20

    # Path to save the timings to a parquet file. (e.g., timing.parquet)
    timing_file: pathlib.Path | None = None

    # Environment to run the policy in.
    env: EnvMode = EnvMode.ALOHA_SIM
class TimingRecorder:
    """Records timing measurements for different keys."""

    def __init__(self) -> None:
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Record a timing measurement for the given key."""
        if key not in self._timings:
            self._timings[key] = []
        self._timings[key].append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Get statistics for the given key."""
        times = self._timings[key]
        return {
            "mean": float(np.mean(times)),
            "std": float(np.std(times)),
            "p25": float(np.quantile(times, 0.25)),
            "p50": float(np.quantile(times, 0.50)),
            "p75": float(np.quantile(times, 0.75)),
            "p90": float(np.quantile(times, 0.90)),
            "p95": float(np.quantile(times, 0.95)),
            "p99": float(np.quantile(times, 0.99)),
        }

    def print_all_stats(self) -> None:
        """Print statistics for all keys in a concise format."""
        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )

        # Add metric column with custom styling
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)

        # Add statistical columns with consistent styling
        stat_columns = [
            ("Mean", "yellow", "mean"),
            ("Std", "yellow", "std"),
            ("P25", "magenta", "p25"),
            ("P50", "magenta", "p50"),
            ("P75", "magenta", "p75"),
            ("P90", "magenta", "p90"),
            ("P95", "magenta", "p95"),
            ("P99", "magenta", "p99"),
        ]
        for name, style, _ in stat_columns:
            table.add_column(name, justify="right", style=style, no_wrap=True)

        # Add rows for each metric with formatted values
        for key in sorted(self._timings.keys()):
            stats = self.get_stats(key)
            values = [f"{stats[stat_key]:.1f}" for _, _, stat_key in stat_columns]
            table.add_row(key, *values)

        # Print with custom console settings
        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Save the timings to a parquet file."""
        logger.info(f"Writing timings to {path}")
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
def main(args: Args) -> None:
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info(f"Server metadata: {policy.get_server_metadata()}")

    # Send a few observations to make sure the model is loaded.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()

    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        inference_start = time.time()
        action = policy.infer(obs_fn())
        timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
        for key, value in action.get("server_timing", {}).items():
            timing_recorder.record(f"server_{key}", value)
        for key, value in action.get("policy_timing", {}).items():
            timing_recorder.record(f"policy_{key}", value)

    timing_recorder.print_all_stats()

    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)
def _random_observation_aloha() -> dict:
    return {
        "state": np.ones((14,)),
        "images": {
            "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
        },
        "prompt": "do something",
    }


def _random_observation_droid() -> dict:
    return {
        "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/joint_position": np.random.rand(7),
        "observation/gripper_position": np.random.rand(1),
        "prompt": "do something",
    }


def _random_observation_libero() -> dict:
    return {
        "observation/state": np.random.rand(8),
        "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "prompt": "do something",
    }


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main(tyro.cli(Args))

View File

@@ -0,0 +1,5 @@
numpy>=1.22.4,<2.0.0
rich
tqdm
tyro
polars

View File

@@ -0,0 +1,30 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
docstring-parser==0.16
# via tyro
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
numpy==1.26.4
# via -r examples/simple_client/requirements.in
polars==1.30.0
# via -r examples/simple_client/requirements.in
pygments==2.19.1
# via rich
rich==14.0.0
# via
# -r examples/simple_client/requirements.in
# tyro
shtab==1.7.2
# via tyro
tqdm==4.67.1
# via -r examples/simple_client/requirements.in
typeguard==4.4.2
# via tyro
typing-extensions==4.13.2
# via
# typeguard
# tyro
tyro==0.9.22
# via -r examples/simple_client/requirements.in

View File

@@ -0,0 +1,142 @@
# UR5 Example
Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map between the UR5 environment and the model's expected inputs and outputs. Check the corresponding file for the Libero environment, `src/openpi/policies/libero_policy.py`, for comments explaining each line.
```python
```python
@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # First, concatenate the joints and gripper into the state vector.
        state = np.concatenate([data["joints"], data["gripper"]])

        # Parse images to uint8 (H,W,C) if needed, since LeRobot stores them as
        # float32 (C,H,W). This step is skipped during policy inference.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist camera, replace with zeros.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # The unused right-wrist "slot" is masked out with False.
                # Exception: pi0-FAST models do not mask padding images, so use True there.
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```
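
To sanity-check the mapping, here is a minimal self-contained sketch (numpy only; the `transforms.DataTransformFn` base class and `_parse_image` helper from openpi are stubbed out, and images are assumed to already be uint8 `(H, W, C)`) that feeds a dummy UR5 observation through the same logic:

```python
import numpy as np


def ur5_inputs(data: dict) -> dict:
    # Same mapping as UR5Inputs above, stripped of the openpi base classes.
    state = np.concatenate([data["joints"], data["gripper"]])
    base_image = data["base_rgb"]
    wrist_image = data["wrist_rgb"]
    return {
        "state": state,
        "image": {
            "base_0_rgb": base_image,
            "left_wrist_0_rgb": wrist_image,
            "right_wrist_0_rgb": np.zeros_like(base_image),  # no right wrist camera
        },
        "image_mask": {
            "base_0_rgb": np.True_,
            "left_wrist_0_rgb": np.True_,
            "right_wrist_0_rgb": np.False_,  # unused slot (non-FAST models)
        },
    }


obs = {
    "joints": np.zeros(6, dtype=np.float32),
    "gripper": np.zeros(1, dtype=np.float32),
    "base_rgb": np.zeros((224, 224, 3), dtype=np.uint8),
    "wrist_rgb": np.zeros((224, 224, 3), dtype=np.uint8),
}
inputs = ur5_inputs(obs)
print(inputs["state"].shape)  # (7,)
```

The 6 joint angles plus 1 gripper value produce a 7-dimensional state vector, and the missing right-wrist image is an all-zeros array with its mask set to `False`.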
Next, we will define the `LeRobotUR5DataConfig` class, which specifies how raw UR5 data from the LeRobot dataset is processed for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
```python
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming is needed here.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
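
The delta-action convention above can be illustrated with plain numpy. This is a hedged sketch, not openpi's actual `DeltaActions` implementation: for every masked (True) dimension, the absolute target is replaced by its offset from the current state, while the unmasked gripper dimension passes through unchanged.

```python
import numpy as np


def make_bool_mask(*dims: int) -> np.ndarray:
    # Positive entries contribute that many True values, negative entries that many
    # False values, mirroring the semantics of _transforms.make_bool_mask(6, -1).
    return np.concatenate([np.full(abs(d), d > 0) for d in dims])


def to_delta(actions: np.ndarray, state: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # Subtract the current state from masked dimensions of every action in the chunk.
    delta = actions.copy()
    delta[..., mask] -= state[mask]
    return delta


mask = make_bool_mask(6, -1)            # [True] * 6 + [False]
state = np.arange(7, dtype=np.float64)  # current robot state
actions = np.ones((3, 7))               # a chunk of 3 absolute actions
delta = to_delta(actions, state, mask)

print(mask.tolist())
print(delta[0].tolist())
```

Applying the inverse (`AbsoluteActions`) at inference time adds the state back to the masked dimensions, recovering the original absolute targets; the gripper column is identical before and after.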
Finally, we define the `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # `task` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```