Initial commit

This commit is contained in:
Ury Zhilinsky
2025-02-03 21:43:26 -08:00
commit 231a1cf7ca
121 changed files with 16349 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.
# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
SHELL ["/bin/bash", "-c"]
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
cmake \
curl \
libffi-dev \
python3-rosdep \
python3-rosinstall \
python3-rosinstall-generator \
whiptail \
git \
wget \
openssh-client \
ros-noetic-cv-bridge \
ros-noetic-usb-cam \
ros-noetic-realsense2-camera \
keyboard-configuration
WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
cd /python && \
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
tar -zxvf Python-3.10.14.tgz && \
cd Python-3.10.14 && \
ls -lhR && \
./configure --enable-optimizations && \
make install && \
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
cd ~ && rm -rf /python && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app
# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]

View File

@@ -0,0 +1,126 @@
# Run Aloha (Real Robot)
This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../../openpi/docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
## Prerequisites
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
## With Docker
```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client
# Run the robot
python examples/aloha_real/main.py
```
Terminal window 2:
```bash
roslaunch --wait aloha ros_nodes.launch
```
Terminal window 3:
```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```
## **ALOHA Checkpoint Guide**
The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
While weve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects weve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot.
---
### **Toast Task**
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
- Works on both real toast and rubber fake toast
- Compatible with standard 2-slice toasters
- Works with plates of varying colors
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower-center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
### **Towel Task**
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
- Works on towels of varying solid colors
- Performance is worse on heavily textured or striped towels
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.
### **Tupperware Task**
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
- The policy has seen plates of varying solid colors.
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
- Tupperware should be on the left.
- Plate should be on the right or bottom.
- The tupperware flap should point toward the plate.
## Training on your own Aloha dataset
1. Convert the dataset to the LeRobot dataset v2.0 format.
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
2. Define a training config that uses the custom dataset.
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoints asset directory within the AssetsConfig.

View File

@@ -0,0 +1,66 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
runtime:
image: aloha_real
depends_on:
- aloha_ros_nodes
- ros_master
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
aloha_ros_nodes:
image: aloha_real
depends_on:
- ros_master
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- /dev:/dev
command: roslaunch --wait aloha ros_nodes.launch
ros_master:
image: ros:noetic-robot
network_mode: host
privileged: true
command:
- roscore
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,71 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
### Task parameters
### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
############################ Helper functions ############################
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
MASTER_POS2JOINT = (
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2

View File

@@ -0,0 +1,272 @@
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""
import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
use_videos: bool = True
tolerance_s: float = 0.0001
image_writer_processes: int = 10
image_writer_threads: int = 5
video_backend: str | None = None
DEFAULT_DATASET_CONFIG = DatasetConfig()
def create_empty_dataset(
repo_id: str,
robot_type: str,
mode: Literal["video", "image"] = "video",
*,
has_velocity: bool = False,
has_effort: bool = False,
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
motors = [
"right_waist",
"right_shoulder",
"right_elbow",
"right_forearm_roll",
"right_wrist_angle",
"right_wrist_rotate",
"right_gripper",
"left_waist",
"left_shoulder",
"left_elbow",
"left_forearm_roll",
"left_wrist_angle",
"left_wrist_rotate",
"left_gripper",
]
cameras = [
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
]
features = {
"observation.state": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
"action": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
}
if has_velocity:
features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
if has_effort:
features["observation.effort"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
for cam in cameras:
features[f"observation.images.{cam}"] = {
"dtype": mode,
"shape": (3, 480, 640),
"names": [
"channels",
"height",
"width",
],
}
if Path(LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
return LeRobotDataset.create(
repo_id=repo_id,
fps=50,
robot_type=robot_type,
features=features,
use_videos=dataset_config.use_videos,
tolerance_s=dataset_config.tolerance_s,
image_writer_processes=dataset_config.image_writer_processes,
image_writer_threads=dataset_config.image_writer_threads,
video_backend=dataset_config.video_backend,
)
def get_cameras(hdf5_files: list[Path]) -> list[str]:
with h5py.File(hdf5_files[0], "r") as ep:
# ignore depth channel, not currently handled
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
def has_velocity(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/qvel" in ep
def has_effort(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/effort" in ep
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
imgs_per_cam = {}
for camera in cameras:
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
if uncompressed:
# load all images in RAM
imgs_array = ep[f"/observations/images/{camera}"][:]
else:
import cv2
# load one compressed image after the other in RAM and uncompress
imgs_array = []
for data in ep[f"/observations/images/{camera}"]:
imgs_array.append(cv2.imdecode(data, 1))
imgs_array = np.array(imgs_array)
imgs_per_cam[camera] = imgs_array
return imgs_per_cam
def load_raw_episode_data(
ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
with h5py.File(ep_path, "r") as ep:
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
velocity = None
if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:])
effort = None
if "/observations/effort" in ep:
effort = torch.from_numpy(ep["/observations/effort"][:])
imgs_per_cam = load_raw_images_per_camera(
ep,
[
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
],
)
return imgs_per_cam, state, action, velocity, effort
def populate_dataset(
dataset: LeRobotDataset,
hdf5_files: list[Path],
task: str,
episodes: list[int] | None = None,
) -> LeRobotDataset:
if episodes is None:
episodes = range(len(hdf5_files))
for ep_idx in tqdm.tqdm(episodes):
ep_path = hdf5_files[ep_idx]
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
num_frames = state.shape[0]
for i in range(num_frames):
frame = {
"observation.state": state[i],
"action": action[i],
}
for camera, img_array in imgs_per_cam.items():
frame[f"observation.images.{camera}"] = img_array[i]
if velocity is not None:
frame["observation.velocity"] = velocity[i]
if effort is not None:
frame["observation.effort"] = effort[i]
dataset.add_frame(frame)
dataset.save_episode(task=task)
return dataset
def port_aloha(
raw_dir: Path,
repo_id: str,
raw_repo_id: str | None = None,
task: str = "DEBUG",
*,
episodes: list[int] | None = None,
push_to_hub: bool = True,
is_mobile: bool = False,
mode: Literal["video", "image"] = "image",
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if not raw_dir.exists():
if raw_repo_id is None:
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
download_raw(raw_dir, repo_id=raw_repo_id)
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
dataset = create_empty_dataset(
repo_id,
robot_type="mobile_aloha" if is_mobile else "aloha",
mode=mode,
has_effort=has_effort(hdf5_files),
has_velocity=has_velocity(hdf5_files),
dataset_config=dataset_config,
)
dataset = populate_dataset(
dataset,
hdf5_files,
task=task,
episodes=episodes,
)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
tyro.cli(port_aloha)

View File

@@ -0,0 +1,57 @@
from typing import List, Optional # noqa: UP035
import einops
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override
from examples.aloha_real import real_env as _real_env
class AlohaRealEnvironment(_environment.Environment):
"""An environment for an Aloha robot on real hardware."""
def __init__(
self,
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
render_height: int = 224,
render_width: int = 224,
) -> None:
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
self._render_height = render_height
self._render_width = render_width
self._ts = None
@override
def reset(self) -> None:
self._ts = self._env.reset()
@override
def is_episode_complete(self) -> bool:
return False
@override
def get_observation(self) -> dict:
if self._ts is None:
raise RuntimeError("Timestep is not set. Call reset() first.")
obs = self._ts.observation
for k in list(obs["images"].keys()):
if "_depth" in k:
del obs["images"][k]
for cam_name in obs["images"]:
img = image_tools.convert_to_uint8(
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
)
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
return {
"state": obs["qpos"],
"images": obs["images"],
}
@override
def apply_action(self, action: dict) -> None:
self._ts = self._env.step(action["actions"])

View File

@@ -0,0 +1,51 @@
import dataclasses
import logging
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro
from examples.aloha_real import env as _env
@dataclasses.dataclass
class Args:
host: str = "0.0.0.0"
port: int = 8000
action_horizon: int = 25
num_episodes: int = 1
max_episode_steps: int = 1000
def main(args: Args) -> None:
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
)
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
metadata = ws_client_policy.get_server_metadata()
runtime = _runtime.Runtime(
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=ws_client_policy,
action_horizon=args.action_horizon,
)
),
subscribers=[],
max_hz=50,
num_episodes=args.num_episodes,
max_episode_steps=args.max_episode_steps,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)

View File

@@ -0,0 +1,171 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time
from typing import Optional, List
import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
from examples.aloha_real import constants
from examples.aloha_real import robot_utils
# This is the reset position that is used by the standard Aloha runtime.
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
class RealEnv:
"""
Environment for real robot bi-manual manipulation
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
"""
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
# reset_position = START_ARM_POSE[:6]
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
self.puppet_bot_left = InterbotixManipulatorXS(
robot_model="vx300s",
group_name="arm",
gripper_name="gripper",
robot_name="puppet_left",
init_node=init_node,
)
self.puppet_bot_right = InterbotixManipulatorXS(
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
)
if setup_robots:
self.setup_robots()
self.recorder_left = robot_utils.Recorder("left", init_node=False)
self.recorder_right = robot_utils.Recorder("right", init_node=False)
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
self.gripper_command = JointSingleCommand(name="gripper")
def setup_robots(self):
robot_utils.setup_puppet_bot(self.puppet_bot_left)
robot_utils.setup_puppet_bot(self.puppet_bot_right)
def get_qpos(self):
left_qpos_raw = self.recorder_left.qpos
right_qpos_raw = self.recorder_right.qpos
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
] # this is position not joint
right_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
] # this is position not joint
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
def get_qvel(self):
left_qvel_raw = self.recorder_left.qvel
right_qvel_raw = self.recorder_right.qvel
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
def get_effort(self):
left_effort_raw = self.recorder_left.effort
right_effort_raw = self.recorder_right.effort
left_robot_effort = left_effort_raw[:7]
right_robot_effort = right_effort_raw[:7]
return np.concatenate([left_robot_effort, right_robot_effort])
def get_images(self):
return self.image_recorder.get_images()
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
self.gripper_command.cmd = left_gripper_desired_joint
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
right_gripper_desired_pos_normalized
)
self.gripper_command.cmd = right_gripper_desired_joint
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
def _reset_joints(self):
robot_utils.move_arms(
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
)
def _reset_gripper(self):
"""Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
)
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
)
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
return obs
def get_reward(self):
return 0
def reset(self, *, fake=False):
if not fake:
# Reboot puppet robot gripper motors
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
self._reset_joints()
self._reset_gripper()
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def step(self, action):
state_len = int(len(action) / 2)
left_action = action[:state_len]
right_action = action[state_len:]
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
self.set_gripper_pose(left_action[-1], right_action[-1])
time.sleep(constants.DT)
return dm_env.TimeStep(
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def get_action(master_bot_left, master_bot_right):
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
# Arm actions
action[:6] = master_bot_left.dxl.joint_states.position[:6]
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
# Gripper actions
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
return action
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)

View File

@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets

View File

@@ -0,0 +1,156 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
catkin-pkg==1.0.0
# via rospkg
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
contourpy==1.1.1
# via matplotlib
cycler==0.12.1
# via matplotlib
distro==1.9.0
# via rospkg
dm-control==1.0.23
# via -r examples/aloha_real/requirements.in
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
docutils==0.20.1
# via catkin-pkg
einops==0.8.0
# via -r examples/aloha_real/requirements.in
etils==1.3.0
# via mujoco
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
h5py==3.11.0
# via -r examples/aloha_real/requirements.in
idna==3.10
# via requests
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.7.5
# via -r examples/aloha_real/requirements.in
mdurl==0.1.2
# via markdown-it-py
modern-robotics==1.1.1
# via -r examples/aloha_real/requirements.in
msgpack==1.1.0
# via -r examples/aloha_real/requirements.in
mujoco==3.2.3
# via dm-control
numpy==1.24.4
# via
# -r examples/aloha_real/requirements.in
# contourpy
# dm-control
# dm-env
# h5py
# labmaze
# matplotlib
# modern-robotics
# mujoco
# opencv-python
# pyquaternion
# scipy
opencv-python==4.10.0.84
# via -r examples/aloha_real/requirements.in
packaging==24.2
# via
# -r examples/aloha_real/requirements.in
# matplotlib
pexpect==4.9.0
# via -r examples/aloha_real/requirements.in
pillow==10.4.0
# via
# -r examples/aloha_real/requirements.in
# matplotlib
protobuf==5.29.1
# via dm-control
ptyprocess==0.7.0
# via pexpect
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.1.4
# via
# catkin-pkg
# dm-control
# matplotlib
pyquaternion==0.9.9
# via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
# via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
# via
# catkin-pkg
# matplotlib
pyyaml==6.0.2
# via
# -r examples/aloha_real/requirements.in
# rospkg
requests==2.32.3
# via
# -r examples/aloha_real/requirements.in
# dm-control
rich==13.9.4
# via tyro
rospkg==1.5.1
# via -r examples/aloha_real/requirements.in
scipy==1.10.1
# via dm-control
setuptools==75.3.0
# via
# catkin-pkg
# dm-control
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_real/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_real/requirements.in
zipp==3.20.2
# via etils

View File

@@ -0,0 +1,275 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time
from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState
from examples.aloha_real import constants
class ImageRecorder:
def __init__(self, init_node=True, is_debug=False):
self.is_debug = is_debug
self.bridge = CvBridge()
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
if init_node:
rospy.init_node("image_recorder", anonymous=True)
for cam_name in self.camera_names:
setattr(self, f"{cam_name}_rgb_image", None)
setattr(self, f"{cam_name}_depth_image", None)
setattr(self, f"{cam_name}_timestamp", 0.0)
if cam_name == "cam_high":
callback_func = self.image_cb_cam_high
elif cam_name == "cam_low":
callback_func = self.image_cb_cam_low
elif cam_name == "cam_left_wrist":
callback_func = self.image_cb_cam_left_wrist
elif cam_name == "cam_right_wrist":
callback_func = self.image_cb_cam_right_wrist
else:
raise NotImplementedError
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
if self.is_debug:
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
time.sleep(0.5)
def image_cb(self, cam_name, data):
setattr(
self,
f"{cam_name}_rgb_image",
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
)
# setattr(
# self,
# f"{cam_name}_depth_image",
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
# )
setattr(
self,
f"{cam_name}_timestamp",
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
)
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
if self.is_debug:
getattr(self, f"{cam_name}_timestamps").append(
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
)
def image_cb_cam_high(self, data):
cam_name = "cam_high"
return self.image_cb(cam_name, data)
def image_cb_cam_low(self, data):
cam_name = "cam_low"
return self.image_cb(cam_name, data)
def image_cb_cam_left_wrist(self, data):
cam_name = "cam_left_wrist"
return self.image_cb(cam_name, data)
def image_cb_cam_right_wrist(self, data):
cam_name = "cam_right_wrist"
return self.image_cb(cam_name, data)
def get_images(self):
image_dict = {}
for cam_name in self.camera_names:
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
time.sleep(0.00001)
rgb_image = getattr(self, f"{cam_name}_rgb_image")
depth_image = getattr(self, f"{cam_name}_depth_image")
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
image_dict[cam_name] = rgb_image
image_dict[f"{cam_name}_depth"] = depth_image
return image_dict
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
for cam_name in self.camera_names:
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
print(f"{cam_name} {image_freq=:.2f}")
print()
class Recorder:
def __init__(self, side, init_node=True, is_debug=False):
self.secs = None
self.nsecs = None
self.qpos = None
self.effort = None
self.arm_command = None
self.gripper_command = None
self.is_debug = is_debug
if init_node:
rospy.init_node("recorder", anonymous=True)
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_group",
JointGroupCommand,
self.puppet_arm_commands_cb,
)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_single",
JointSingleCommand,
self.puppet_gripper_commands_cb,
)
if self.is_debug:
self.joint_timestamps = deque(maxlen=50)
self.arm_command_timestamps = deque(maxlen=50)
self.gripper_command_timestamps = deque(maxlen=50)
time.sleep(0.1)
def puppet_state_cb(self, data):
self.qpos = data.position
self.qvel = data.velocity
self.effort = data.effort
self.data = data
if self.is_debug:
self.joint_timestamps.append(time.time())
def puppet_arm_commands_cb(self, data):
self.arm_command = data.cmd
if self.is_debug:
self.arm_command_timestamps.append(time.time())
def puppet_gripper_commands_cb(self, data):
self.gripper_command = data.cmd
if self.is_debug:
self.gripper_command_timestamps.append(time.time())
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
joint_freq = 1 / dt_helper(self.joint_timestamps)
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
def get_arm_joint_positions(bot):
return bot.arm.core.joint_states.position[:6]
def get_arm_gripper_positions(bot):
return bot.gripper.core.joint_states.position[6]
def move_arms(bot_list, target_pose_list, move_time=1):
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
for t in range(num_steps):
for bot_id, bot in enumerate(bot_list):
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
time.sleep(constants.DT)
def move_grippers(bot_list, target_pose_list, move_time):
print(f"Moving grippers to {target_pose_list=}")
gripper_command = JointSingleCommand(name="gripper")
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
for t in range(num_steps):
d = {}
for bot_id, bot in enumerate(bot_list):
gripper_command.cmd = traj_list[bot_id][t]
bot.gripper.core.pub_single.publish(gripper_command)
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
f.write(json.dumps(d) + "\n")
time.sleep(constants.DT)
def setup_puppet_bot(bot):
bot.dxl.robot_reboot_motors("single", "gripper", True)
bot.dxl.robot_set_operating_modes("group", "arm", "position")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_on(bot)
def setup_master_bot(bot):
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_off(bot)
def set_standard_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def set_low_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def torque_off(bot):
bot.dxl.robot_torque_enable("group", "arm", False)
bot.dxl.robot_torque_enable("single", "gripper", False)
def torque_on(bot):
bot.dxl.robot_torque_enable("group", "arm", True)
bot.dxl.robot_torque_enable("single", "gripper", True)
# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
print("\nSyncing!")
# activate master arms
torque_on(master_bot_left)
torque_on(master_bot_right)
# get puppet arm positions
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
# get puppet gripper positions
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
# move master arms to puppet positions
move_arms(
[master_bot_left, master_bot_right],
[puppet_left_qpos, puppet_right_qpos],
move_time=1,
)
# move master grippers to puppet positions
move_grippers(
[master_bot_left, master_bot_right],
[puppet_left_gripper, puppet_right_gripper],
move_time=1,
)

View File

@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoDisplay(_subscriber.Subscriber):
"""Displays video frames."""
def __init__(self) -> None:
self._ax: plt.Axes | None = None
self._plt_img: plt.Image | None = None
@override
def on_episode_start(self) -> None:
plt.ion()
self._ax = plt.subplot()
self._plt_img = None
@override
def on_step(self, observation: dict, action: dict) -> None:
assert self._ax is not None
im = observation["image"][0] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
if self._plt_img is None:
self._plt_img = self._ax.imshow(im)
else:
self._plt_img.set_data(im)
plt.pause(0.001)
@override
def on_episode_end(self) -> None:
plt.ioff()
plt.close()

View File

@@ -0,0 +1,41 @@
# Dockerfile for the Aloha simulation environment.
# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev
ENV MUJOCO_GL=egl
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]

View File

@@ -0,0 +1,36 @@
# Run Aloha Sim
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client
# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```
Note: If you are seeing EGL errors, you may need to install the following dependencies:
```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```

View File

@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
runtime:
image: aloha_sim
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_sim/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

56
examples/aloha_sim/env.py Normal file
View File

@@ -0,0 +1,56 @@
import gym_aloha # noqa: F401
import gymnasium
import numpy as np
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override
class AlohaSimEnvironment(_environment.Environment):
"""An environment for an Aloha robot in simulation."""
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
np.random.seed(seed)
self._rng = np.random.default_rng(seed)
self._gym = gymnasium.make(task, obs_type=obs_type)
self._last_obs = None
self._done = True
self._episode_reward = 0.0
@override
def reset(self) -> None:
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = False
self._episode_reward = 0.0
@override
def is_episode_complete(self) -> bool:
return self._done
@override
def get_observation(self) -> dict:
if self._last_obs is None:
raise RuntimeError("Observation is not set. Call reset() first.")
return self._last_obs # type: ignore
@override
def apply_action(self, action: dict) -> None:
gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = terminated or truncated
self._episode_reward = max(self._episode_reward, reward)
def _convert_observation(self, gym_obs: dict) -> dict:
img = gym_obs["pixels"]["top"]
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
# Convert axis order from [H, W, C] --> [C, H, W]
img = np.transpose(img, (2, 0, 1))
return {
"state": gym_obs["agent_pos"],
"images": {"cam_high": img},
}

View File

@@ -0,0 +1,55 @@
import dataclasses
import logging
import pathlib
import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro
@dataclasses.dataclass
class Args:
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
task: str = "gym_aloha/AlohaTransferCube-v0"
seed: int = 0
action_horizon: int = 10
host: str = "0.0.0.0"
port: int = 8000
display: bool = False
def main(args: Args) -> None:
runtime = _runtime.Runtime(
environment=_env.AlohaSimEnvironment(
task=args.task,
seed=args.seed,
),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=_websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
),
action_horizon=args.action_horizon,
)
),
subscribers=[
_saver.VideoSaver(args.out_dir),
],
max_hz=50,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)

View File

@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy
typing-extensions
tyro
websockets

View File

@@ -0,0 +1,132 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
cloudpickle==3.1.0
# via gymnasium
contourpy==1.3.1
# via matplotlib
cycler==0.12.1
# via matplotlib
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
farama-notifications==0.0.4
# via gymnasium
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
gym-aloha==0.1.1
# via -r examples/aloha_sim/requirements.in
gymnasium==1.0.0
# via gym-aloha
idna==3.10
# via requests
imageio==2.36.1
# via
# -r examples/aloha_sim/requirements.in
# gym-aloha
imageio-ffmpeg==0.5.1
# via imageio
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.9.3
# via -r examples/aloha_sim/requirements.in
mdurl==0.1.2
# via markdown-it-py
msgpack==1.1.0
# via -r examples/aloha_sim/requirements.in
mujoco==2.3.7
# via
# dm-control
# gym-aloha
numpy==1.26.4
# via
# -r examples/aloha_sim/requirements.in
# contourpy
# dm-control
# dm-env
# gymnasium
# imageio
# labmaze
# matplotlib
# mujoco
# scipy
packaging==24.2
# via matplotlib
pillow==11.0.0
# via
# imageio
# matplotlib
protobuf==5.29.1
# via dm-control
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.2.0
# via
# dm-control
# matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
requests==2.32.3
# via dm-control
rich==13.9.4
# via tyro
scipy==1.14.1
# via dm-control
setuptools==75.6.0
# via
# dm-control
# imageio-ffmpeg
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.1
# via tyro
typing-extensions==4.12.2
# via
# -r examples/aloha_sim/requirements.in
# gymnasium
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_sim/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_sim/requirements.in

View File

@@ -0,0 +1,40 @@
import logging
import pathlib
import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoSaver(_subscriber.Subscriber):
"""Saves episode data."""
def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
out_dir.mkdir(parents=True, exist_ok=True)
self._out_dir = out_dir
self._images: list[np.ndarray] = []
self._subsample = subsample
@override
def on_episode_start(self) -> None:
self._images = []
@override
def on_step(self, observation: dict, action: dict) -> None:
im = observation["images"]["cam_high"] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
self._images.append(im)
@override
def on_episode_end(self) -> None:
existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
out_path = self._out_dir / f"out_{next_idx}.mp4"
logging.info(f"Saving video to {out_path}")
imageio.mimwrite(
out_path,
[np.asarray(x) for x in self._images[:: self._subsample]],
fps=50 // max(1, self._subsample),
)

46
examples/droid/README.md Normal file
View File

@@ -0,0 +1,46 @@
# Run DROID
This example shows how to run the fine-tuned $\pi_0$-FAST-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). We also offer a $\pi_0$-DROID model that is fine-tuned from $\pi_0$ and uses flow action decoding. You can use it by replacing `pi0_fast_droid` with `pi0_droid` in the commands below. In practice, we find that out-of-the-box, the $\pi_0$-FAST-DROID model is better at following language commands, so we recommend it as the default checkpoint for DROID evaluation. If you want to fine-tune on a DROID task that requires a fast-to-inference policy, you may still want to consider using the $\pi_0$-DROID model, since it decodes faster. For more details, please see the [FAST paper](https://pi.website/research/fast).
## Step 1: Start a policy server
Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
2. Start the OpenPI server via the following command:
```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid
```
You can also run the equivalent command below:
```bash
uv run scripts/serve_policy.py --env=DROID
```
## Step 2: Run the DROID robot
1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
2. On the control laptop, activate your DROID conda environment.
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"].
```bash
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
```
The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
# Troubleshooting
| Issue | Solution |
|-------|----------|
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |

237
examples/droid/main.py Normal file
View File

@@ -0,0 +1,237 @@
# ruff: noqa
import contextlib
import dataclasses
import datetime
import faulthandler
import os
import signal
from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from droid.robot_env import RobotEnv
import tqdm
import tyro
faulthandler.enable()
@dataclasses.dataclass
class Args:
# Hardware parameters
left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
# Policy parameters
external_camera: str | None = (
None # which external camera should be fed to the policy, choose from ["left", "right"]
)
# Rollout parameters
max_timesteps: int = 600
# How many actions to execute from a predicted action chunk before querying policy server again
# 8 is usually a good default (equals 0.5 seconds of action execution).
open_loop_horizon: int = 8
# Remote server parameters
remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
remote_port: int = (
8000 # point this to the port of the policy server, default server port for openpi servers is 8000
)
# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
# waiting for a new action chunk, it will raise an exception and the server connection dies.
# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
interrupted = False
original_handler = signal.getsignal(signal.SIGINT)
def handler(signum, frame):
nonlocal interrupted
interrupted = True
signal.signal(signal.SIGINT, handler)
try:
yield
finally:
signal.signal(signal.SIGINT, original_handler)
if interrupted:
raise KeyboardInterrupt
def main(args: Args):
# Make sure external camera is specified by user -- we only use one external camera for the policy
assert (
args.external_camera is not None and args.external_camera in ["left", "right"]
), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
# Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
print("Created the droid env!")
# Connect to the policy server
policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
while True:
instruction = input("Enter instruction: ")
# Rollout parameters
actions_from_chunk_completed = 0
pred_action_chunk = None
# Prepare to save video of rollout
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
video = []
bar = tqdm.tqdm(range(args.max_timesteps))
print("Running rollout... press Ctrl+C to stop early.")
for t_step in bar:
try:
# Get the current observation
curr_obs = _extract_observation(
args,
env.get_observation(),
# Save the first observation to disk
save_to_disk=t_step == 0,
)
video.append(curr_obs[f"{args.external_camera}_image"])
# Send websocket request to policy server if it's time to predict a new chunk
if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
actions_from_chunk_completed = 0
# We resize images on the robot laptop to minimize the amount of data sent to the policy server
# and improve latency.
request_data = {
"observation/exterior_image_1_left": image_tools.resize_with_pad(
curr_obs[f"{args.external_camera}_image"], 224, 224
),
"observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
"observation/joint_position": curr_obs["joint_position"],
"observation/gripper_position": curr_obs["gripper_position"],
"prompt": instruction,
}
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
# Ctrl+C will be handled after the server call is complete
with prevent_keyboard_interrupt():
# this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
pred_action_chunk = policy_client.infer(request_data)["actions"]
assert pred_action_chunk.shape == (10, 8)
# Select current action to execute from chunk
action = pred_action_chunk[actions_from_chunk_completed]
actions_from_chunk_completed += 1
# Binarize gripper action
if action[-1].item() > 0.5:
# action[-1] = 1.0
action = np.concatenate([action[:-1], np.ones((1,))])
else:
# action[-1] = 0.0
action = np.concatenate([action[:-1], np.zeros((1,))])
# clip all dimensions of action to [-1, 1]
action = np.clip(action, -1, 1)
env.step(action)
except KeyboardInterrupt:
break
video = np.stack(video)
save_filename = "video_" + timestamp
ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
success: str | float | None = None
while not isinstance(success, float):
success = input(
"Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
)
if success == "y":
success = 1.0
elif success == "n":
success = 0.0
success = float(success) / 100
if not (0 <= success <= 1):
print(f"Success must be a number in [0, 100] but got: {success * 100}")
df = df.append(
{
"success": success,
"duration": t_step,
"video_filename": save_filename,
},
ignore_index=True,
)
if input("Do one more eval? (enter y or n) ").lower() != "y":
break
env.reset()
os.makedirs("results", exist_ok=True)
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
df.to_csv(csv_filename)
print(f"Results saved to {csv_filename}")
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
image_observations = obs_dict["image"]
left_image, right_image, wrist_image = None, None, None
for key in image_observations:
# Note the "left" below refers to the left camera in the stereo pair.
# The model is only trained on left stereo cams, so we only feed those.
if args.left_camera_id in key and "left" in key:
left_image = image_observations[key]
elif args.right_camera_id in key and "left" in key:
right_image = image_observations[key]
elif args.wrist_camera_id in key and "left" in key:
wrist_image = image_observations[key]
# Drop the alpha dimension
left_image = left_image[..., :3]
right_image = right_image[..., :3]
wrist_image = wrist_image[..., :3]
# Convert to RGB
left_image = left_image[..., ::-1]
right_image = right_image[..., ::-1]
wrist_image = wrist_image[..., ::-1]
# In addition to image observations, also capture the proprioceptive state
robot_state = obs_dict["robot_state"]
cartesian_position = np.array(robot_state["cartesian_position"])
joint_position = np.array(robot_state["joint_positions"])
gripper_position = np.array([robot_state["gripper_position"]])
# Save the images to disk so that they can be viewed live while the robot is running
# Create one combined image to make live viewing easy
if save_to_disk:
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
combined_image = Image.fromarray(combined_image)
combined_image.save("robot_camera_views.png")
return {
"left_image": left_image,
"right_image": right_image,
"wrist_image": wrist_image,
"cartesian_position": cartesian_position,
"joint_position": joint_position,
"gripper_position": gripper_position,
}
if __name__ == "__main__":
args: Args = tyro.cli(Args)
main(args)

137
examples/inference.ipynb Normal file
View File

@@ -0,0 +1,137 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import dataclasses\n",
"\n",
"import jax\n",
"\n",
"from openpi.models import model as _model\n",
"from openpi.policies import droid_policy\n",
"from openpi.policies import policy_config as _policy_config\n",
"from openpi.shared import download\n",
"from openpi.training import config as _config\n",
"from openpi.training import data_loader as _data_loader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Policy inference\n",
"\n",
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_fast_droid\")\n",
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
"\n",
"# Create a trained policy.\n",
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
"\n",
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
"example = droid_policy.make_droid_example()\n",
"result = policy.infer(example)\n",
"\n",
"# Delete the policy to free up memory.\n",
"del policy\n",
"\n",
"print(\"Actions shape:\", result[\"actions\"].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Working with a live model\n",
"\n",
"\n",
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_aloha_sim\")\n",
"\n",
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
"key = jax.random.key(0)\n",
"\n",
"# Create a model from the checkpoint.\n",
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
"\n",
"# We can create fake observations and actions to test the model.\n",
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
"\n",
"# Sample actions from the model.\n",
"loss = model.compute_loss(key, obs, act)\n",
"print(\"Loss shape:\", loss.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reduce the batch size to reduce memory usage.\n",
"config = dataclasses.replace(config, batch_size=2)\n",
"\n",
"# Load a single batch of data. This is the same data that will be used during training.\n",
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
"obs, act = next(iter(loader))\n",
"\n",
"# Sample actions from the model.\n",
"loss = model.compute_loss(key, obs, act)\n",
"\n",
"# Delete the model to free up memory.\n",
"del model\n",
"\n",
"print(\"Loss shape:\", loss.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,59 @@
# Dockerfile for the LIBERO benchmark.
# Build the container:
# docker build . -t libero -f examples/libero/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
make \
g++ \
clang \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
# Create a default config file to avoid an input prompt from LIBERO's init script.
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
ENV LIBERO_CONFIG_PATH=/tmp/libero
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
benchmark_root: /app/third_party/libero/libero/libero
bddl_files: /app/third_party/libero/libero/libero/bddl_files
init_states: /app/third_party/libero/libero/libero/init_files
datasets: /app/third_party/libero/libero/datasets
assets: /app/third_party/libero/libero/libero/assets
EOF
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"]

56
examples/libero/README.md Normal file
View File

@@ -0,0 +1,56 @@
# LIBERO Benchmark
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
This example requires git submodules to be initialized. Don't forget to run:
```bash
git submodule update --init --recursive
```
## With Docker
```bash
# Grant access to the X11 server:
sudo xhost +local:docker
export SERVER_ARGS="--env LIBERO"
docker compose -f examples/libero/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.8 examples/libero/.venv
source examples/libero/.venv/bin/activate
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
uv pip install -e packages/openpi-client
uv pip install -e third_party/libero
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
# Run the simulation
python examples/libero/main.py
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env LIBERO
```
## Results
If you follow the training instructions and hyperparameters in the `pi0_libero` and `pi0_fast_libero` configs, you should get results similar to the following:
| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|-------|---------------|---------------|-------------|-----------|---------|
| π0-FAST @ 30k (finetuned) | 96.4 | 96.8 | 88.6 | 60.2 | 85.5 |
| π0 @ 30k (finetuned) | 96.8 | 98.8 | 95.8 | 85.2 | 94.15 |
Note that the hyperparameters for these runs are not tuned and $\pi_0$-FAST does not use a FAST tokenizer optimized for Libero. Likely, the results could be improved with more tuning, we mainly use these results as an example of how to use openpi to fine-tune $\pi_0$ models on a new dataset.

View File

@@ -0,0 +1,52 @@
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
runtime:
image: libero
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/libero/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
- /tmp/.X11-unix:/tmp/.X11-unix:ro
environment:
- DISPLAY=$DISPLAY
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,106 @@
"""
Minimal example script for converting a dataset to LeRobot format.
We use the Libero dataset (stored in RLDS) for this example, but it can be easily
modified for any other data you have saved in a custom format.
Usage:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
Note: to run the script, you need to install tensorflow_datasets:
`uv pip install tensorflow tensorflow_datasets`
You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
The resulting dataset will get saved to the $LEROBOT_HOME directory.
Running this conversion script will take approximately 30 minutes.
"""
import shutil
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds
import tyro
REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
RAW_DATASET_NAMES = [
"libero_10_no_noops",
"libero_goal_no_noops",
"libero_object_no_noops",
"libero_spatial_no_noops",
] # For simplicity we will combine multiple Libero datasets into one training dataset
def main(data_dir: str, *, push_to_hub: bool = False):
# Clean up any existing dataset in the output directory
output_path = LEROBOT_HOME / REPO_NAME
if output_path.exists():
shutil.rmtree(output_path)
# Create LeRobot dataset, define features to store
# OpenPi assumes that proprio is stored in `state` and actions in `action`
# LeRobot assumes that dtype of image data is `image`
dataset = LeRobotDataset.create(
repo_id=REPO_NAME,
robot_type="panda",
fps=10,
features={
"image": {
"dtype": "image",
"shape": (256, 256, 3),
"names": ["height", "width", "channel"],
},
"wrist_image": {
"dtype": "image",
"shape": (256, 256, 3),
"names": ["height", "width", "channel"],
},
"state": {
"dtype": "float32",
"shape": (8,),
"names": ["state"],
},
"actions": {
"dtype": "float32",
"shape": (7,),
"names": ["actions"],
},
},
image_writer_threads=10,
image_writer_processes=5,
)
# Loop over raw Libero datasets and write episodes to the LeRobot dataset
# You can modify this for your own data format
for raw_dataset_name in RAW_DATASET_NAMES:
raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
for episode in raw_dataset:
for step in episode["steps"].as_numpy_iterator():
dataset.add_frame(
{
"image": step["observation"]["image"],
"wrist_image": step["observation"]["wrist_image"],
"state": step["observation"]["state"],
"actions": step["action"],
}
)
dataset.save_episode(task=step["language_instruction"].decode())
# Consolidate the dataset, skip computing stats since we will do that later
dataset.consolidate(run_compute_stats=False)
# Optionally push to the Hugging Face Hub
if push_to_hub:
dataset.push_to_hub(
tags=["libero", "panda", "rlds"],
private=False,
push_videos=True,
license="apache-2.0",
)
if __name__ == "__main__":
tyro.cli(main)

219
examples/libero/main.py Normal file
View File

@@ -0,0 +1,219 @@
import collections
import dataclasses
import logging
import math
import pathlib
import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
@dataclasses.dataclass
class Args:
#################################################################################################################
# Model server parameters
#################################################################################################################
host: str = "0.0.0.0"
port: int = 8000
resize_size: int = 224
replan_steps: int = 5
#################################################################################################################
# LIBERO environment-specific parameters
#################################################################################################################
task_suite_name: str = (
"libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
)
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize i n sim
num_trials_per_task: int = 50 # Number of rollouts per task
#################################################################################################################
# Utils
#################################################################################################################
video_out_path: str = "data/libero/videos" # Path to save videos
seed: int = 7 # Random Seed (for reproducibility)
def eval_libero(args: Args) -> None:
# Set random seed
np.random.seed(args.seed)
# Initialize LIBERO task suite
benchmark_dict = benchmark.get_benchmark_dict()
task_suite = benchmark_dict[args.task_suite_name]()
num_tasks_in_suite = task_suite.n_tasks
logging.info(f"Task suite: {args.task_suite_name}")
pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
if args.task_suite_name == "libero_spatial":
max_steps = 220 # longest training demo has 193 steps
elif args.task_suite_name == "libero_object":
max_steps = 280 # longest training demo has 254 steps
elif args.task_suite_name == "libero_goal":
max_steps = 300 # longest training demo has 270 steps
elif args.task_suite_name == "libero_10":
max_steps = 520 # longest training demo has 505 steps
elif args.task_suite_name == "libero_90":
max_steps = 400 # longest training demo has 373 steps
else:
raise ValueError(f"Unknown task suite: {args.task_suite_name}")
client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
# Start evaluation
total_episodes, total_successes = 0, 0
for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
# Get task
task = task_suite.get_task(task_id)
# Get default LIBERO initial states
initial_states = task_suite.get_task_init_states(task_id)
# Initialize LIBERO environment and task description
env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
# Start episodes
task_episodes, task_successes = 0, 0
for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
logging.info(f"\nTask: {task_description}")
# Reset environment
env.reset()
action_plan = collections.deque()
# Set initial states
obs = env.set_init_state(initial_states[episode_idx])
# Setup
t = 0
replay_images = []
logging.info(f"Starting episode {task_episodes+1}...")
while t < max_steps + args.num_steps_wait:
try:
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
# and we need to wait for them to fall
if t < args.num_steps_wait:
obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
t += 1
continue
# Get preprocessed image
# IMPORTANT: rotate 180 degrees to match train preprocessing
img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
img = image_tools.convert_to_uint8(
image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
)
wrist_img = image_tools.convert_to_uint8(
image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
)
# Save preprocessed image for replay video
replay_images.append(img)
if not action_plan:
# Finished executing previous action chunk -- compute new chunk
# Prepare observations dict
element = {
"observation/image": img,
"observation/wrist_image": wrist_img,
"observation/state": np.concatenate(
(
obs["robot0_eef_pos"],
_quat2axisangle(obs["robot0_eef_quat"]),
obs["robot0_gripper_qpos"],
)
),
"prompt": str(task_description),
}
# Query model to get action
action_chunk = client.infer(element)["actions"]
assert (
len(action_chunk) >= args.replan_steps
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
action_plan.extend(action_chunk[: args.replan_steps])
action = action_plan.popleft()
# Execute action in environment
obs, reward, done, info = env.step(action.tolist())
if done:
task_successes += 1
total_successes += 1
break
t += 1
except Exception as e:
logging.error(f"Caught exception: {e}")
break
task_episodes += 1
total_episodes += 1
# Save a replay video of the episode
suffix = "success" if done else "failure"
task_segment = task_description.replace(" ", "_")
imageio.mimwrite(
pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
[np.asarray(x) for x in replay_images],
fps=10,
)
# Log current results
logging.info(f"Success: {done}")
logging.info(f"# episodes completed so far: {total_episodes}")
logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
# Log final results
logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
logging.info(f"Total episodes: {total_episodes}")
def _get_libero_env(task, resolution, seed):
"""Initializes and returns the LIBERO environment, along with the task description."""
task_description = task.language
task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
env = OffScreenRenderEnv(**env_args)
env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
return env, task_description
def _quat2axisangle(quat):
"""
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
tyro.cli(eval_libero)

View File

@@ -0,0 +1,11 @@
imageio[ffmpeg]
numpy==1.22.4
tqdm
tyro
PyYaml
opencv-python==4.6.0.66
torch==1.11.0+cu113
torchvision==0.12.0+cu113
torchaudio==0.11.0+cu113
robosuite==1.4.1
matplotlib==3.5.3

View File

@@ -0,0 +1,136 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
absl-py==2.1.0
# via mujoco
certifi==2024.12.14
# via requests
charset-normalizer==3.4.0
# via requests
cycler==0.12.1
# via matplotlib
docstring-parser==0.16
# via tyro
etils==1.3.0
# via mujoco
eval-type-backport==0.2.0
# via tyro
evdev==1.7.1
# via pynput
fonttools==4.55.3
# via matplotlib
glfw==1.12.0
# via mujoco
idna==3.10
# via requests
imageio==2.35.1
# via -r examples/libero/requirements.in
imageio-ffmpeg==0.5.1
# via imageio
importlib-metadata==8.5.0
# via typeguard
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
llvmlite==0.36.0
# via numba
markdown-it-py==3.0.0
# via rich
matplotlib==3.5.3
# via -r examples/libero/requirements.in
mdurl==0.1.2
# via markdown-it-py
mujoco==3.2.3
# via robosuite
numba==0.53.1
# via robosuite
numpy==1.22.4
# via
# -r examples/libero/requirements.in
# imageio
# matplotlib
# mujoco
# numba
# opencv-python
# robosuite
# scipy
# torchvision
opencv-python==4.6.0.66
# via
# -r examples/libero/requirements.in
# robosuite
packaging==24.2
# via matplotlib
pillow==10.4.0
# via
# imageio
# matplotlib
# robosuite
# torchvision
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pynput==1.7.7
# via robosuite
pyopengl==3.1.7
# via mujoco
pyparsing==3.1.4
# via matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
python-xlib==0.33
# via pynput
pyyaml==6.0.2
# via -r examples/libero/requirements.in
requests==2.32.3
# via torchvision
rich==13.9.4
# via tyro
robosuite==1.4.1
# via -r examples/libero/requirements.in
scipy==1.10.1
# via robosuite
setuptools==75.3.0
# via
# imageio-ffmpeg
# numba
shtab==1.7.1
# via tyro
six==1.17.0
# via
# pynput
# python-dateutil
# python-xlib
termcolor==2.4.0
# via robosuite
torch==1.11.0+cu113
# via
# -r examples/libero/requirements.in
# torchaudio
# torchvision
torchaudio==0.11.0+cu113
# via -r examples/libero/requirements.in
torchvision==0.12.0+cu113
# via -r examples/libero/requirements.in
tqdm==4.67.1
# via -r examples/libero/requirements.in
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# torch
# torchvision
# typeguard
# tyro
tyro==0.9.2
# via -r examples/libero/requirements.in
urllib3==2.2.3
# via requests
zipp==3.20.2
# via
# etils
# importlib-metadata
# importlib-resources

View File

@@ -0,0 +1,134 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pathlib\n",
"\n",
"import numpy as np\n",
"\n",
"record_path = pathlib.Path(\"../policy_records\")\n",
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
"\n",
"records = []\n",
"for i in range(num_steps):\n",
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
" records.append(record)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"length of records\", len(records))\n",
"print(\"keys in records\", records[0].keys())\n",
"\n",
"for k in records[0]:\n",
" print(f\"{k} shape: {records[0][k].shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"def get_image(step: int, idx: int = 0):\n",
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
" return img[idx].transpose(1, 2, 0)\n",
"\n",
"\n",
"def show_image(step: int, idx_lst: list[int]):\n",
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
" return Image.fromarray(np.hstack(imgs))\n",
"\n",
"\n",
"for i in range(2):\n",
" display(show_image(i, [0]))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def get_axis(name, axis):\n",
" return np.array([record[name][axis] for record in records])\n",
"\n",
"\n",
"# qpos is [..., 14] of type float:\n",
"# 0-5: left arm joint angles\n",
"# 6: left arm gripper\n",
"# 7-12: right arm joint angles\n",
"# 13: right arm gripper\n",
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
"\n",
"\n",
"def make_data():\n",
" cur_dim = 0\n",
" in_data = {}\n",
" out_data = {}\n",
" for name, dim_size in names:\n",
" for i in range(dim_size):\n",
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
" cur_dim += 1\n",
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
"\n",
"\n",
"in_data, out_data = make_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in in_data.columns:\n",
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
" data.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,32 @@
# Dockerfile for the simple client.
# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"

View File

@@ -0,0 +1,30 @@
# Simple Client
A minimal client that sends observations to the server and prints the inference rate.
You can specifiy which runtime environment to use using the `--env` flag. You can see the available options by running:
```bash
uv run examples/simple_client/main.py --help
```
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/simple_client/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
uv run examples/simple_client/main.py --env DROID
```
Terminal window 2:
```bash
uv run scripts/serve_policy.py --env DROID
```

View File

@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
runtime:
image: simple_client
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/simple_client/Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,89 @@
import dataclasses
import enum
import logging
import time
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import tyro
class EnvMode(enum.Enum):
"""Supported environments."""
ALOHA = "aloha"
ALOHA_SIM = "aloha_sim"
DROID = "droid"
LIBERO = "libero"
@dataclasses.dataclass
class Args:
host: str = "0.0.0.0"
port: int = 8000
env: EnvMode = EnvMode.ALOHA_SIM
num_steps: int = 10
def main(args: Args) -> None:
obs_fn = {
EnvMode.ALOHA: _random_observation_aloha,
EnvMode.ALOHA_SIM: _random_observation_aloha,
EnvMode.DROID: _random_observation_droid,
EnvMode.LIBERO: _random_observation_libero,
}[args.env]
policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
)
logging.info(f"Server metadata: {policy.get_server_metadata()}")
# Send 1 observation to make sure the model is loaded.
policy.infer(obs_fn())
start = time.time()
for _ in range(args.num_steps):
policy.infer(obs_fn())
end = time.time()
print(f"Total time taken: {end - start:.2f} s")
print(f"Average inference time: {1000 * (end - start) / args.num_steps:.2f} ms")
def _random_observation_aloha() -> dict:
return {
"state": np.ones((14,)),
"images": {
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
},
"prompt": "do something",
}
def _random_observation_droid() -> dict:
return {
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/joint_position": np.random.rand(7),
"observation/gripper_position": np.random.rand(1),
"prompt": "do something",
}
def _random_observation_libero() -> dict:
return {
"observation/state": np.random.rand(8),
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"prompt": "do something",
}
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main(tyro.cli(Args))

View File

@@ -0,0 +1,2 @@
numpy
tyro

View File

@@ -0,0 +1,27 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
backports-cached-property==1.0.2
# via tyro
docstring-parser==0.16
# via tyro
eval-type-backport==0.1.3
# via tyro
markdown-it-py==2.2.0
# via rich
mdurl==0.1.2
# via markdown-it-py
numpy==1.21.6
# via -r examples/simple_client/requirements.in
pygments==2.17.2
# via rich
rich==13.8.1
# via tyro
shtab==1.7.1
# via tyro
typing-extensions==4.7.1
# via
# markdown-it-py
# rich
# tyro
tyro==0.9.1
# via -r examples/simple_client/requirements.in