Initial commit
This commit is contained in:
70
examples/aloha_real/Dockerfile
Normal file
70
examples/aloha_real/Dockerfile
Normal file
@@ -0,0 +1,70 @@
|
||||
# Dockerfile for the Aloha real environment.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
|
||||
|
||||
FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cmake \
|
||||
curl \
|
||||
libffi-dev \
|
||||
python3-rosdep \
|
||||
python3-rosinstall \
|
||||
python3-rosinstall-generator \
|
||||
whiptail \
|
||||
git \
|
||||
wget \
|
||||
openssh-client \
|
||||
ros-noetic-cv-bridge \
|
||||
ros-noetic-usb-cam \
|
||||
ros-noetic-realsense2-camera \
|
||||
keyboard-configuration
|
||||
|
||||
WORKDIR /root
|
||||
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
|
||||
RUN chmod +x xsarm_amd64_install.sh
|
||||
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
|
||||
|
||||
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
|
||||
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
|
||||
|
||||
# Install python 3.10 because this ROS image comes with 3.8
|
||||
RUN mkdir /python && \
|
||||
cd /python && \
|
||||
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
|
||||
tar -zxvf Python-3.10.14.tgz && \
|
||||
cd Python-3.10.14 && \
|
||||
ls -lhR && \
|
||||
./configure --enable-optimizations && \
|
||||
make install && \
|
||||
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
||||
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
||||
cd ~ && rm -rf /python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
|
||||
ENV UV_HTTP_TIMEOUT=120
|
||||
ENV UV_LINK_MODE=copy
|
||||
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
|
||||
WORKDIR /app
|
||||
|
||||
# Create an entrypoint script to run the setup commands, followed by the command passed in.
|
||||
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
|
||||
#!/bin/bash
|
||||
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
|
||||
EOF
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["python3", "/app/examples/aloha_real/main.py"]
|
||||
126
examples/aloha_real/README.md
Normal file
126
examples/aloha_real/README.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Run Aloha (Real Robot)
|
||||
|
||||
This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../../openpi/docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
|
||||
|
||||
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
|
||||
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
|
||||
docker compose -f examples/aloha_real/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.10 examples/aloha_real/.venv
|
||||
source examples/aloha_real/.venv/bin/activate
|
||||
uv pip sync examples/aloha_real/requirements.txt
|
||||
uv pip install -e packages/openpi-client
|
||||
|
||||
# Run the robot
|
||||
python examples/aloha_real/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
roslaunch --wait aloha ros_nodes.launch
|
||||
```
|
||||
|
||||
Terminal window 3:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
|
||||
```
|
||||
|
||||
## **ALOHA Checkpoint Guide**
|
||||
|
||||
|
||||
The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
|
||||
|
||||
While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot.
|
||||
|
||||
|
||||
---
|
||||
|
||||
### **Toast Task**
|
||||
|
||||
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
|
||||
|
||||
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_base`
|
||||
- **Prompt**: "take the toast out of the toaster"
|
||||
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
|
||||
- **Object Distribution**:
|
||||
- Works on both real toast and rubber fake toast
|
||||
- Compatible with standard 2-slice toasters
|
||||
- Works with plates of varying colors
|
||||
|
||||
### **Scene Setup Guidelines**
|
||||
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
|
||||
|
||||
- The toaster should be positioned in the top-left quadrant of the workspace.
|
||||
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
|
||||
- The plate should be placed roughly in the lower-center of the workspace.
|
||||
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
|
||||
|
||||
|
||||
### **Towel Task**
|
||||
|
||||
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
|
||||
|
||||
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_towel`
|
||||
- **Prompt**: "fold the towel"
|
||||
- **Object Distribution**:
|
||||
- Works on towels of varying solid colors
|
||||
- Performance is worse on heavily textured or striped towels
|
||||
|
||||
### **Scene Setup Guidelines**
|
||||
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
|
||||
|
||||
- The towel should be flattened and roughly centered on the table.
|
||||
- Choose a towel that does not blend in with the table surface.
|
||||
|
||||
|
||||
### **Tupperware Task**
|
||||
|
||||
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
|
||||
|
||||
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_tupperware`
|
||||
- **Prompt**: "open the tupperware and put the food on the plate"
|
||||
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
|
||||
- **Object Distribution**:
|
||||
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
|
||||
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
|
||||
- The policy has seen plates of varying solid colors.
|
||||
|
||||
### **Scene Setup Guidelines**
|
||||
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
|
||||
|
||||
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
|
||||
- Positioning:
|
||||
- Tupperware should be on the left.
|
||||
- Plate should be on the right or bottom.
|
||||
- The tupperware flap should point toward the plate.
|
||||
|
||||
## Training on your own Aloha dataset
|
||||
|
||||
1. Convert the dataset to the LeRobot dataset v2.0 format.
|
||||
|
||||
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
|
||||
|
||||
|
||||
2. Define a training config that uses the custom dataset.
|
||||
|
||||
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
|
||||
|
||||
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
|
||||
66
examples/aloha_real/compose.yml
Normal file
66
examples/aloha_real/compose.yml
Normal file
@@ -0,0 +1,66 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/aloha_real/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: aloha_real
|
||||
depends_on:
|
||||
- aloha_ros_nodes
|
||||
- ros_master
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_real/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
|
||||
aloha_ros_nodes:
|
||||
image: aloha_real
|
||||
depends_on:
|
||||
- ros_master
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_real/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- /dev:/dev
|
||||
command: roslaunch --wait aloha ros_nodes.launch
|
||||
|
||||
ros_master:
|
||||
image: ros:noetic-robot
|
||||
network_mode: host
|
||||
privileged: true
|
||||
command:
|
||||
- roscore
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
71
examples/aloha_real/constants.py
Normal file
71
examples/aloha_real/constants.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
|
||||
### Task parameters
|
||||
|
||||
### ALOHA fixed constants
|
||||
DT = 0.001
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
|
||||
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
|
||||
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
MASTER_POS2JOINT = (
|
||||
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
PUPPET_POS2JOINT = (
|
||||
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
272
examples/aloha_real/convert_aloha_data_to_lerobot.py
Normal file
272
examples/aloha_real/convert_aloha_data_to_lerobot.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
|
||||
|
||||
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Literal
|
||||
|
||||
import h5py
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DatasetConfig:
|
||||
use_videos: bool = True
|
||||
tolerance_s: float = 0.0001
|
||||
image_writer_processes: int = 10
|
||||
image_writer_threads: int = 5
|
||||
video_backend: str | None = None
|
||||
|
||||
|
||||
DEFAULT_DATASET_CONFIG = DatasetConfig()
|
||||
|
||||
|
||||
def create_empty_dataset(
|
||||
repo_id: str,
|
||||
robot_type: str,
|
||||
mode: Literal["video", "image"] = "video",
|
||||
*,
|
||||
has_velocity: bool = False,
|
||||
has_effort: bool = False,
|
||||
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
||||
) -> LeRobotDataset:
|
||||
motors = [
|
||||
"right_waist",
|
||||
"right_shoulder",
|
||||
"right_elbow",
|
||||
"right_forearm_roll",
|
||||
"right_wrist_angle",
|
||||
"right_wrist_rotate",
|
||||
"right_gripper",
|
||||
"left_waist",
|
||||
"left_shoulder",
|
||||
"left_elbow",
|
||||
"left_forearm_roll",
|
||||
"left_wrist_angle",
|
||||
"left_wrist_rotate",
|
||||
"left_gripper",
|
||||
]
|
||||
cameras = [
|
||||
"cam_high",
|
||||
"cam_low",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist",
|
||||
]
|
||||
|
||||
features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
if has_velocity:
|
||||
features["observation.velocity"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
if has_effort:
|
||||
features["observation.effort"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
for cam in cameras:
|
||||
features[f"observation.images.{cam}"] = {
|
||||
"dtype": mode,
|
||||
"shape": (3, 480, 640),
|
||||
"names": [
|
||||
"channels",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
}
|
||||
|
||||
if Path(LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
return LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=50,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
use_videos=dataset_config.use_videos,
|
||||
tolerance_s=dataset_config.tolerance_s,
|
||||
image_writer_processes=dataset_config.image_writer_processes,
|
||||
image_writer_threads=dataset_config.image_writer_threads,
|
||||
video_backend=dataset_config.video_backend,
|
||||
)
|
||||
|
||||
|
||||
def get_cameras(hdf5_files: list[Path]) -> list[str]:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
# ignore depth channel, not currently handled
|
||||
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
||||
|
||||
|
||||
def has_velocity(hdf5_files: list[Path]) -> bool:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/qvel" in ep
|
||||
|
||||
|
||||
def has_effort(hdf5_files: list[Path]) -> bool:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/effort" in ep
|
||||
|
||||
|
||||
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
|
||||
imgs_per_cam = {}
|
||||
for camera in cameras:
|
||||
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
|
||||
|
||||
if uncompressed:
|
||||
# load all images in RAM
|
||||
imgs_array = ep[f"/observations/images/{camera}"][:]
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# load one compressed image after the other in RAM and uncompress
|
||||
imgs_array = []
|
||||
for data in ep[f"/observations/images/{camera}"]:
|
||||
imgs_array.append(cv2.imdecode(data, 1))
|
||||
imgs_array = np.array(imgs_array)
|
||||
|
||||
imgs_per_cam[camera] = imgs_array
|
||||
return imgs_per_cam
|
||||
|
||||
|
||||
def load_raw_episode_data(
|
||||
ep_path: Path,
|
||||
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
velocity = None
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
|
||||
effort = None
|
||||
if "/observations/effort" in ep:
|
||||
effort = torch.from_numpy(ep["/observations/effort"][:])
|
||||
|
||||
imgs_per_cam = load_raw_images_per_camera(
|
||||
ep,
|
||||
[
|
||||
"cam_high",
|
||||
"cam_low",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist",
|
||||
],
|
||||
)
|
||||
|
||||
return imgs_per_cam, state, action, velocity, effort
|
||||
|
||||
|
||||
def populate_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
hdf5_files: list[Path],
|
||||
task: str,
|
||||
episodes: list[int] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
if episodes is None:
|
||||
episodes = range(len(hdf5_files))
|
||||
|
||||
for ep_idx in tqdm.tqdm(episodes):
|
||||
ep_path = hdf5_files[ep_idx]
|
||||
|
||||
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
|
||||
num_frames = state.shape[0]
|
||||
|
||||
for i in range(num_frames):
|
||||
frame = {
|
||||
"observation.state": state[i],
|
||||
"action": action[i],
|
||||
}
|
||||
|
||||
for camera, img_array in imgs_per_cam.items():
|
||||
frame[f"observation.images.{camera}"] = img_array[i]
|
||||
|
||||
if velocity is not None:
|
||||
frame["observation.velocity"] = velocity[i]
|
||||
if effort is not None:
|
||||
frame["observation.effort"] = effort[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=task)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def port_aloha(
|
||||
raw_dir: Path,
|
||||
repo_id: str,
|
||||
raw_repo_id: str | None = None,
|
||||
task: str = "DEBUG",
|
||||
*,
|
||||
episodes: list[int] | None = None,
|
||||
push_to_hub: bool = True,
|
||||
is_mobile: bool = False,
|
||||
mode: Literal["video", "image"] = "image",
|
||||
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
||||
):
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
if raw_repo_id is None:
|
||||
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
|
||||
download_raw(raw_dir, repo_id=raw_repo_id)
|
||||
|
||||
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
||||
|
||||
dataset = create_empty_dataset(
|
||||
repo_id,
|
||||
robot_type="mobile_aloha" if is_mobile else "aloha",
|
||||
mode=mode,
|
||||
has_effort=has_effort(hdf5_files),
|
||||
has_velocity=has_velocity(hdf5_files),
|
||||
dataset_config=dataset_config,
|
||||
)
|
||||
dataset = populate_dataset(
|
||||
dataset,
|
||||
hdf5_files,
|
||||
task=task,
|
||||
episodes=episodes,
|
||||
)
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(port_aloha)
|
||||
57
examples/aloha_real/env.py
Normal file
57
examples/aloha_real/env.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import List, Optional # noqa: UP035
|
||||
|
||||
import einops
|
||||
from openpi_client import image_tools
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
from examples.aloha_real import real_env as _real_env
|
||||
|
||||
|
||||
class AlohaRealEnvironment(_environment.Environment):
|
||||
"""An environment for an Aloha robot on real hardware."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
|
||||
render_height: int = 224,
|
||||
render_width: int = 224,
|
||||
) -> None:
|
||||
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
|
||||
self._render_height = render_height
|
||||
self._render_width = render_width
|
||||
|
||||
self._ts = None
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self._ts = self._env.reset()
|
||||
|
||||
@override
|
||||
def is_episode_complete(self) -> bool:
|
||||
return False
|
||||
|
||||
@override
|
||||
def get_observation(self) -> dict:
|
||||
if self._ts is None:
|
||||
raise RuntimeError("Timestep is not set. Call reset() first.")
|
||||
|
||||
obs = self._ts.observation
|
||||
for k in list(obs["images"].keys()):
|
||||
if "_depth" in k:
|
||||
del obs["images"][k]
|
||||
|
||||
for cam_name in obs["images"]:
|
||||
img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
|
||||
)
|
||||
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
|
||||
|
||||
return {
|
||||
"state": obs["qpos"],
|
||||
"images": obs["images"],
|
||||
}
|
||||
|
||||
@override
|
||||
def apply_action(self, action: dict) -> None:
|
||||
self._ts = self._env.step(action["actions"])
|
||||
51
examples/aloha_real/main.py
Normal file
51
examples/aloha_real/main.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import tyro
|
||||
|
||||
from examples.aloha_real import env as _env
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
action_horizon: int = 25
|
||||
|
||||
num_episodes: int = 1
|
||||
max_episode_steps: int = 1000
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
|
||||
|
||||
metadata = ws_client_policy.get_server_metadata()
|
||||
runtime = _runtime.Runtime(
|
||||
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
|
||||
agent=_policy_agent.PolicyAgent(
|
||||
policy=action_chunk_broker.ActionChunkBroker(
|
||||
policy=ws_client_policy,
|
||||
action_horizon=args.action_horizon,
|
||||
)
|
||||
),
|
||||
subscribers=[],
|
||||
max_hz=50,
|
||||
num_episodes=args.num_episodes,
|
||||
max_episode_steps=args.max_episode_steps,
|
||||
)
|
||||
|
||||
runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
171
examples/aloha_real/real_env.py
Normal file
171
examples/aloha_real/real_env.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
import collections
|
||||
import time
|
||||
from typing import Optional, List
|
||||
import dm_env
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
|
||||
from examples.aloha_real import constants
|
||||
from examples.aloha_real import robot_utils
|
||||
|
||||
# This is the reset position that is used by the standard Aloha runtime.
|
||||
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
|
||||
|
||||
|
||||
class RealEnv:
|
||||
"""
|
||||
Environment for real robot bi-manual manipulation
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
|
||||
# reset_position = START_ARM_POSE[:6]
|
||||
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
|
||||
|
||||
self.puppet_bot_left = InterbotixManipulatorXS(
|
||||
robot_model="vx300s",
|
||||
group_name="arm",
|
||||
gripper_name="gripper",
|
||||
robot_name="puppet_left",
|
||||
init_node=init_node,
|
||||
)
|
||||
self.puppet_bot_right = InterbotixManipulatorXS(
|
||||
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
|
||||
)
|
||||
if setup_robots:
|
||||
self.setup_robots()
|
||||
|
||||
self.recorder_left = robot_utils.Recorder("left", init_node=False)
|
||||
self.recorder_right = robot_utils.Recorder("right", init_node=False)
|
||||
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
|
||||
self.gripper_command = JointSingleCommand(name="gripper")
|
||||
|
||||
def setup_robots(self):
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_left)
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_right)
|
||||
|
||||
def get_qpos(self):
|
||||
left_qpos_raw = self.recorder_left.qpos
|
||||
right_qpos_raw = self.recorder_right.qpos
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
right_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
def get_qvel(self):
|
||||
left_qvel_raw = self.recorder_left.qvel
|
||||
right_qvel_raw = self.recorder_right.qvel
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
|
||||
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
def get_effort(self):
|
||||
left_effort_raw = self.recorder_left.effort
|
||||
right_effort_raw = self.recorder_right.effort
|
||||
left_robot_effort = left_effort_raw[:7]
|
||||
right_robot_effort = right_effort_raw[:7]
|
||||
return np.concatenate([left_robot_effort, right_robot_effort])
|
||||
|
||||
def get_images(self):
|
||||
return self.image_recorder.get_images()
|
||||
|
||||
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
|
||||
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
|
||||
self.gripper_command.cmd = left_gripper_desired_joint
|
||||
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
|
||||
right_gripper_desired_pos_normalized
|
||||
)
|
||||
self.gripper_command.cmd = right_gripper_desired_joint
|
||||
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
def _reset_joints(self):
|
||||
robot_utils.move_arms(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
|
||||
)
|
||||
|
||||
def _reset_gripper(self):
|
||||
"""Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
|
||||
)
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
|
||||
)
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
return obs
|
||||
|
||||
def get_reward(self):
|
||||
return 0
|
||||
|
||||
def reset(self, *, fake=False):
|
||||
if not fake:
|
||||
# Reboot puppet robot gripper motors
|
||||
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self._reset_joints()
|
||||
self._reset_gripper()
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
state_len = int(len(action) / 2)
|
||||
left_action = action[:state_len]
|
||||
right_action = action[state_len:]
|
||||
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
||||
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
||||
self.set_gripper_pose(left_action[-1], right_action[-1])
|
||||
time.sleep(constants.DT)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
||||
# Arm actions
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# Gripper actions
|
||||
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
||||
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
|
||||
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
|
||||
18
examples/aloha_real/requirements.in
Normal file
18
examples/aloha_real/requirements.in
Normal file
@@ -0,0 +1,18 @@
|
||||
Pillow
|
||||
dm_control
|
||||
einops
|
||||
h5py
|
||||
matplotlib
|
||||
modern_robotics
|
||||
msgpack
|
||||
numpy
|
||||
opencv-python
|
||||
packaging
|
||||
pexpect
|
||||
pyquaternion
|
||||
pyrealsense2
|
||||
pyyaml
|
||||
requests
|
||||
rospkg
|
||||
tyro
|
||||
websockets
|
||||
156
examples/aloha_real/requirements.txt
Normal file
156
examples/aloha_real/requirements.txt
Normal file
@@ -0,0 +1,156 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
catkin-pkg==1.0.0
|
||||
# via rospkg
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
contourpy==1.1.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
distro==1.9.0
|
||||
# via rospkg
|
||||
dm-control==1.0.23
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
docutils==0.20.1
|
||||
# via catkin-pkg
|
||||
einops==0.8.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
h5py==3.11.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
idna==3.10
|
||||
# via requests
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.7.5
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
modern-robotics==1.1.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mujoco==3.2.3
|
||||
# via dm-control
|
||||
numpy==1.24.4
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# h5py
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# modern-robotics
|
||||
# mujoco
|
||||
# opencv-python
|
||||
# pyquaternion
|
||||
# scipy
|
||||
opencv-python==4.10.0.84
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
packaging==24.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
pexpect==4.9.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.1.4
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pyrealsense2==2.55.1.6486
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# matplotlib
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# rospkg
|
||||
requests==2.32.3
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
rospkg==1.5.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
scipy==1.10.1
|
||||
# via dm-control
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
zipp==3.20.2
|
||||
# via etils
|
||||
275
examples/aloha_real/robot_utils.py
Normal file
275
examples/aloha_real/robot_utils.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
from collections import deque
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
|
||||
from aloha.msg import RGBGrayscaleImage
|
||||
from cv_bridge import CvBridge
|
||||
from interbotix_xs_msgs.msg import JointGroupCommand
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
from examples.aloha_real import constants
|
||||
|
||||
|
||||
class ImageRecorder:
|
||||
def __init__(self, init_node=True, is_debug=False):
|
||||
self.is_debug = is_debug
|
||||
self.bridge = CvBridge()
|
||||
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("image_recorder", anonymous=True)
|
||||
for cam_name in self.camera_names:
|
||||
setattr(self, f"{cam_name}_rgb_image", None)
|
||||
setattr(self, f"{cam_name}_depth_image", None)
|
||||
setattr(self, f"{cam_name}_timestamp", 0.0)
|
||||
if cam_name == "cam_high":
|
||||
callback_func = self.image_cb_cam_high
|
||||
elif cam_name == "cam_low":
|
||||
callback_func = self.image_cb_cam_low
|
||||
elif cam_name == "cam_left_wrist":
|
||||
callback_func = self.image_cb_cam_left_wrist
|
||||
elif cam_name == "cam_right_wrist":
|
||||
callback_func = self.image_cb_cam_right_wrist
|
||||
else:
|
||||
raise NotImplementedError
|
||||
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
|
||||
if self.is_debug:
|
||||
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
|
||||
|
||||
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
|
||||
time.sleep(0.5)
|
||||
|
||||
def image_cb(self, cam_name, data):
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_rgb_image",
|
||||
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
|
||||
)
|
||||
# setattr(
|
||||
# self,
|
||||
# f"{cam_name}_depth_image",
|
||||
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
|
||||
# )
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_timestamp",
|
||||
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
|
||||
)
|
||||
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
|
||||
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
|
||||
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
|
||||
if self.is_debug:
|
||||
getattr(self, f"{cam_name}_timestamps").append(
|
||||
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
|
||||
)
|
||||
|
||||
def image_cb_cam_high(self, data):
|
||||
cam_name = "cam_high"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_low(self, data):
|
||||
cam_name = "cam_low"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_left_wrist(self, data):
|
||||
cam_name = "cam_left_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_right_wrist(self, data):
|
||||
cam_name = "cam_right_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def get_images(self):
|
||||
image_dict = {}
|
||||
for cam_name in self.camera_names:
|
||||
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
|
||||
time.sleep(0.00001)
|
||||
rgb_image = getattr(self, f"{cam_name}_rgb_image")
|
||||
depth_image = getattr(self, f"{cam_name}_depth_image")
|
||||
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
|
||||
image_dict[cam_name] = rgb_image
|
||||
image_dict[f"{cam_name}_depth"] = depth_image
|
||||
return image_dict
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
for cam_name in self.camera_names:
|
||||
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
|
||||
print(f"{cam_name} {image_freq=:.2f}")
|
||||
print()
|
||||
|
||||
|
||||
class Recorder:
|
||||
def __init__(self, side, init_node=True, is_debug=False):
|
||||
self.secs = None
|
||||
self.nsecs = None
|
||||
self.qpos = None
|
||||
self.effort = None
|
||||
self.arm_command = None
|
||||
self.gripper_command = None
|
||||
self.is_debug = is_debug
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("recorder", anonymous=True)
|
||||
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_group",
|
||||
JointGroupCommand,
|
||||
self.puppet_arm_commands_cb,
|
||||
)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_single",
|
||||
JointSingleCommand,
|
||||
self.puppet_gripper_commands_cb,
|
||||
)
|
||||
if self.is_debug:
|
||||
self.joint_timestamps = deque(maxlen=50)
|
||||
self.arm_command_timestamps = deque(maxlen=50)
|
||||
self.gripper_command_timestamps = deque(maxlen=50)
|
||||
time.sleep(0.1)
|
||||
|
||||
def puppet_state_cb(self, data):
|
||||
self.qpos = data.position
|
||||
self.qvel = data.velocity
|
||||
self.effort = data.effort
|
||||
self.data = data
|
||||
if self.is_debug:
|
||||
self.joint_timestamps.append(time.time())
|
||||
|
||||
def puppet_arm_commands_cb(self, data):
|
||||
self.arm_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.arm_command_timestamps.append(time.time())
|
||||
|
||||
def puppet_gripper_commands_cb(self, data):
|
||||
self.gripper_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.gripper_command_timestamps.append(time.time())
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
||||
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
||||
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
||||
|
||||
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
||||
|
||||
|
||||
def get_arm_joint_positions(bot):
|
||||
return bot.arm.core.joint_states.position[:6]
|
||||
|
||||
|
||||
def get_arm_gripper_positions(bot):
|
||||
return bot.gripper.core.joint_states.position[6]
|
||||
|
||||
|
||||
def move_arms(bot_list, target_pose_list, move_time=1):
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
for t in range(num_steps):
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def move_grippers(bot_list, target_pose_list, move_time):
|
||||
print(f"Moving grippers to {target_pose_list=}")
|
||||
gripper_command = JointSingleCommand(name="gripper")
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
|
||||
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
||||
for t in range(num_steps):
|
||||
d = {}
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
gripper_command.cmd = traj_list[bot_id][t]
|
||||
bot.gripper.core.pub_single.publish(gripper_command)
|
||||
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
||||
f.write(json.dumps(d) + "\n")
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def setup_puppet_bot(bot):
|
||||
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_on(bot)
|
||||
|
||||
|
||||
def setup_master_bot(bot):
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_off(bot)
|
||||
|
||||
|
||||
def set_standard_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def set_low_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def torque_off(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", False)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", False)
|
||||
|
||||
|
||||
def torque_on(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", True)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", True)
|
||||
|
||||
|
||||
# for DAgger
|
||||
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
||||
print("\nSyncing!")
|
||||
|
||||
# activate master arms
|
||||
torque_on(master_bot_left)
|
||||
torque_on(master_bot_right)
|
||||
|
||||
# get puppet arm positions
|
||||
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
||||
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
||||
|
||||
# get puppet gripper positions
|
||||
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
||||
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
||||
|
||||
# move master arms to puppet positions
|
||||
move_arms(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_qpos, puppet_right_qpos],
|
||||
move_time=1,
|
||||
)
|
||||
|
||||
# move master grippers to puppet positions
|
||||
move_grippers(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_gripper, puppet_right_gripper],
|
||||
move_time=1,
|
||||
)
|
||||
36
examples/aloha_real/video_display.py
Normal file
36
examples/aloha_real/video_display.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoDisplay(_subscriber.Subscriber):
|
||||
"""Displays video frames."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._ax: plt.Axes | None = None
|
||||
self._plt_img: plt.Image | None = None
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
plt.ion()
|
||||
self._ax = plt.subplot()
|
||||
self._plt_img = None
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
assert self._ax is not None
|
||||
|
||||
im = observation["image"][0] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
|
||||
if self._plt_img is None:
|
||||
self._plt_img = self._ax.imshow(im)
|
||||
else:
|
||||
self._plt_img.set_data(im)
|
||||
plt.pause(0.001)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
plt.ioff()
|
||||
plt.close()
|
||||
41
examples/aloha_sim/Dockerfile
Normal file
41
examples/aloha_sim/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
# Dockerfile for the Aloha simulation environment.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
|
||||
|
||||
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
libosmesa6-dev \
|
||||
libgl1-mesa-glx \
|
||||
libglew-dev \
|
||||
libglfw3-dev \
|
||||
libgles2-mesa-dev
|
||||
ENV MUJOCO_GL=egl
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
||||
|
||||
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
|
||||
36
examples/aloha_sim/README.md
Normal file
36
examples/aloha_sim/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Run Aloha Sim
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA_SIM"
|
||||
docker compose -f examples/aloha_sim/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.10 examples/aloha_sim/.venv
|
||||
source examples/aloha_sim/.venv/bin/activate
|
||||
uv pip sync examples/aloha_sim/requirements.txt
|
||||
uv pip install -e packages/openpi-client
|
||||
|
||||
# Run the simulation
|
||||
MUJOCO_GL=egl python examples/aloha_sim/main.py
|
||||
```
|
||||
|
||||
Note: If you are seeing EGL errors, you may need to install the following dependencies:
|
||||
|
||||
```bash
|
||||
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env ALOHA_SIM
|
||||
```
|
||||
42
examples/aloha_sim/compose.yml
Normal file
42
examples/aloha_sim/compose.yml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/aloha_sim/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: aloha_sim
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_sim/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
56
examples/aloha_sim/env.py
Normal file
56
examples/aloha_sim/env.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import gym_aloha # noqa: F401
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class AlohaSimEnvironment(_environment.Environment):
|
||||
"""An environment for an Aloha robot in simulation."""
|
||||
|
||||
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
|
||||
np.random.seed(seed)
|
||||
self._rng = np.random.default_rng(seed)
|
||||
|
||||
self._gym = gymnasium.make(task, obs_type=obs_type)
|
||||
|
||||
self._last_obs = None
|
||||
self._done = True
|
||||
self._episode_reward = 0.0
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
|
||||
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
||||
self._done = False
|
||||
self._episode_reward = 0.0
|
||||
|
||||
@override
|
||||
def is_episode_complete(self) -> bool:
|
||||
return self._done
|
||||
|
||||
@override
|
||||
def get_observation(self) -> dict:
|
||||
if self._last_obs is None:
|
||||
raise RuntimeError("Observation is not set. Call reset() first.")
|
||||
|
||||
return self._last_obs # type: ignore
|
||||
|
||||
@override
|
||||
def apply_action(self, action: dict) -> None:
|
||||
gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
|
||||
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
||||
self._done = terminated or truncated
|
||||
self._episode_reward = max(self._episode_reward, reward)
|
||||
|
||||
def _convert_observation(self, gym_obs: dict) -> dict:
|
||||
img = gym_obs["pixels"]["top"]
|
||||
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
|
||||
# Convert axis order from [H, W, C] --> [C, H, W]
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
|
||||
return {
|
||||
"state": gym_obs["agent_pos"],
|
||||
"images": {"cam_high": img},
|
||||
}
|
||||
55
examples/aloha_sim/main.py
Normal file
55
examples/aloha_sim/main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import env as _env
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import saver as _saver
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
|
||||
|
||||
task: str = "gym_aloha/AlohaTransferCube-v0"
|
||||
seed: int = 0
|
||||
|
||||
action_horizon: int = 10
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
display: bool = False
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
runtime = _runtime.Runtime(
|
||||
environment=_env.AlohaSimEnvironment(
|
||||
task=args.task,
|
||||
seed=args.seed,
|
||||
),
|
||||
agent=_policy_agent.PolicyAgent(
|
||||
policy=action_chunk_broker.ActionChunkBroker(
|
||||
policy=_websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
),
|
||||
action_horizon=args.action_horizon,
|
||||
)
|
||||
),
|
||||
subscribers=[
|
||||
_saver.VideoSaver(args.out_dir),
|
||||
],
|
||||
max_hz=50,
|
||||
)
|
||||
|
||||
runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
8
examples/aloha_sim/requirements.in
Normal file
8
examples/aloha_sim/requirements.in
Normal file
@@ -0,0 +1,8 @@
|
||||
gym-aloha
|
||||
imageio
|
||||
matplotlib
|
||||
msgpack
|
||||
numpy
|
||||
typing-extensions
|
||||
tyro
|
||||
websockets
|
||||
132
examples/aloha_sim/requirements.txt
Normal file
132
examples/aloha_sim/requirements.txt
Normal file
@@ -0,0 +1,132 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cloudpickle==3.1.0
|
||||
# via gymnasium
|
||||
contourpy==1.3.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
gym-aloha==0.1.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
gymnasium==1.0.0
|
||||
# via gym-aloha
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.36.1
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gym-aloha
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.9.3
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# gymnasium
|
||||
# imageio
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# scipy
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==11.0.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.0
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
requests==2.32.3
|
||||
# via dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
scipy==1.14.1
|
||||
# via dm-control
|
||||
setuptools==75.6.0
|
||||
# via
|
||||
# dm-control
|
||||
# imageio-ffmpeg
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.1
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gymnasium
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
40
examples/aloha_sim/saver.py
Normal file
40
examples/aloha_sim/saver.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoSaver(_subscriber.Subscriber):
|
||||
"""Saves episode data."""
|
||||
|
||||
def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._out_dir = out_dir
|
||||
self._images: list[np.ndarray] = []
|
||||
self._subsample = subsample
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
self._images = []
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
im = observation["images"]["cam_high"] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
self._images.append(im)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
|
||||
next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
|
||||
out_path = self._out_dir / f"out_{next_idx}.mp4"
|
||||
|
||||
logging.info(f"Saving video to {out_path}")
|
||||
imageio.mimwrite(
|
||||
out_path,
|
||||
[np.asarray(x) for x in self._images[:: self._subsample]],
|
||||
fps=50 // max(1, self._subsample),
|
||||
)
|
||||
46
examples/droid/README.md
Normal file
46
examples/droid/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Run DROID
|
||||
|
||||
This example shows how to run the fine-tuned $\pi_0$-FAST-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). We also offer a $\pi_0$-DROID model that is fine-tuned from $\pi_0$ and uses flow action decoding. You can use it by replacing `pi0_fast_droid` with `pi0_droid` in the commands below. In practice, we find that out-of-the-box, the $\pi_0$-FAST-DROID model is better at following language commands, so we recommend it as the default checkpoint for DROID evaluation. If you want to fine-tune on a DROID task that requires a fast-to-inference policy, you may still want to consider using the $\pi_0$-DROID model, since it decodes faster. For more details, please see the [FAST paper](https://pi.website/research/fast).
|
||||
|
||||
|
||||
## Step 1: Start a policy server
|
||||
|
||||
Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
|
||||
|
||||
1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
|
||||
2. Start the OpenPI server via the following command:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid
|
||||
```
|
||||
|
||||
You can also run the equivalent command below:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env=DROID
|
||||
```
|
||||
|
||||
## Step 2: Run the DROID robot
|
||||
|
||||
1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
|
||||
2. On the control laptop, activate your DROID conda environment.
|
||||
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
|
||||
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
|
||||
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
|
||||
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
|
||||
7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"].
|
||||
|
||||
```bash
|
||||
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
|
||||
```
|
||||
|
||||
The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
|
||||
|
||||
# Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
|
||||
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
|
||||
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
|
||||
| Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
|
||||
237
examples/droid/main.py
Normal file
237
examples/droid/main.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# ruff: noqa
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import datetime
|
||||
import faulthandler
|
||||
import os
|
||||
import signal
|
||||
|
||||
from moviepy.editor import ImageSequenceClip
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from droid.robot_env import RobotEnv
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
# Hardware parameters
|
||||
left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
|
||||
right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
|
||||
wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
|
||||
|
||||
# Policy parameters
|
||||
external_camera: str | None = (
|
||||
None # which external camera should be fed to the policy, choose from ["left", "right"]
|
||||
)
|
||||
|
||||
# Rollout parameters
|
||||
max_timesteps: int = 600
|
||||
# How many actions to execute from a predicted action chunk before querying policy server again
|
||||
# 8 is usually a good default (equals 0.5 seconds of action execution).
|
||||
open_loop_horizon: int = 8
|
||||
|
||||
# Remote server parameters
|
||||
remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
|
||||
remote_port: int = (
|
||||
8000 # point this to the port of the policy server, default server port for openpi servers is 8000
|
||||
)
|
||||
|
||||
|
||||
# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
|
||||
# waiting for a new action chunk, it will raise an exception and the server connection dies.
|
||||
# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
|
||||
@contextlib.contextmanager
|
||||
def prevent_keyboard_interrupt():
|
||||
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
|
||||
interrupted = False
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
|
||||
def handler(signum, frame):
|
||||
nonlocal interrupted
|
||||
interrupted = True
|
||||
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
if interrupted:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
def main(args: Args):
|
||||
# Make sure external camera is specified by user -- we only use one external camera for the policy
|
||||
assert (
|
||||
args.external_camera is not None and args.external_camera in ["left", "right"]
|
||||
), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
|
||||
|
||||
# Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
|
||||
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
|
||||
print("Created the droid env!")
|
||||
|
||||
# Connect to the policy server
|
||||
policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
|
||||
|
||||
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
|
||||
|
||||
while True:
|
||||
instruction = input("Enter instruction: ")
|
||||
|
||||
# Rollout parameters
|
||||
actions_from_chunk_completed = 0
|
||||
pred_action_chunk = None
|
||||
|
||||
# Prepare to save video of rollout
|
||||
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
|
||||
video = []
|
||||
bar = tqdm.tqdm(range(args.max_timesteps))
|
||||
print("Running rollout... press Ctrl+C to stop early.")
|
||||
for t_step in bar:
|
||||
try:
|
||||
# Get the current observation
|
||||
curr_obs = _extract_observation(
|
||||
args,
|
||||
env.get_observation(),
|
||||
# Save the first observation to disk
|
||||
save_to_disk=t_step == 0,
|
||||
)
|
||||
|
||||
video.append(curr_obs[f"{args.external_camera}_image"])
|
||||
|
||||
# Send websocket request to policy server if it's time to predict a new chunk
|
||||
if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
|
||||
actions_from_chunk_completed = 0
|
||||
|
||||
# We resize images on the robot laptop to minimize the amount of data sent to the policy server
|
||||
# and improve latency.
|
||||
request_data = {
|
||||
"observation/exterior_image_1_left": image_tools.resize_with_pad(
|
||||
curr_obs[f"{args.external_camera}_image"], 224, 224
|
||||
),
|
||||
"observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
|
||||
"observation/joint_position": curr_obs["joint_position"],
|
||||
"observation/gripper_position": curr_obs["gripper_position"],
|
||||
"prompt": instruction,
|
||||
}
|
||||
|
||||
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
|
||||
# Ctrl+C will be handled after the server call is complete
|
||||
with prevent_keyboard_interrupt():
|
||||
# this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
|
||||
pred_action_chunk = policy_client.infer(request_data)["actions"]
|
||||
assert pred_action_chunk.shape == (10, 8)
|
||||
|
||||
# Select current action to execute from chunk
|
||||
action = pred_action_chunk[actions_from_chunk_completed]
|
||||
actions_from_chunk_completed += 1
|
||||
|
||||
# Binarize gripper action
|
||||
if action[-1].item() > 0.5:
|
||||
# action[-1] = 1.0
|
||||
action = np.concatenate([action[:-1], np.ones((1,))])
|
||||
else:
|
||||
# action[-1] = 0.0
|
||||
action = np.concatenate([action[:-1], np.zeros((1,))])
|
||||
|
||||
# clip all dimensions of action to [-1, 1]
|
||||
action = np.clip(action, -1, 1)
|
||||
|
||||
env.step(action)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
video = np.stack(video)
|
||||
save_filename = "video_" + timestamp
|
||||
ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
|
||||
|
||||
success: str | float | None = None
|
||||
while not isinstance(success, float):
|
||||
success = input(
|
||||
"Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
|
||||
)
|
||||
if success == "y":
|
||||
success = 1.0
|
||||
elif success == "n":
|
||||
success = 0.0
|
||||
|
||||
success = float(success) / 100
|
||||
if not (0 <= success <= 1):
|
||||
print(f"Success must be a number in [0, 100] but got: {success * 100}")
|
||||
|
||||
df = df.append(
|
||||
{
|
||||
"success": success,
|
||||
"duration": t_step,
|
||||
"video_filename": save_filename,
|
||||
},
|
||||
ignore_index=True,
|
||||
)
|
||||
|
||||
if input("Do one more eval? (enter y or n) ").lower() != "y":
|
||||
break
|
||||
env.reset()
|
||||
|
||||
os.makedirs("results", exist_ok=True)
|
||||
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
|
||||
csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
|
||||
df.to_csv(csv_filename)
|
||||
print(f"Results saved to {csv_filename}")
|
||||
|
||||
|
||||
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
|
||||
image_observations = obs_dict["image"]
|
||||
left_image, right_image, wrist_image = None, None, None
|
||||
for key in image_observations:
|
||||
# Note the "left" below refers to the left camera in the stereo pair.
|
||||
# The model is only trained on left stereo cams, so we only feed those.
|
||||
if args.left_camera_id in key and "left" in key:
|
||||
left_image = image_observations[key]
|
||||
elif args.right_camera_id in key and "left" in key:
|
||||
right_image = image_observations[key]
|
||||
elif args.wrist_camera_id in key and "left" in key:
|
||||
wrist_image = image_observations[key]
|
||||
|
||||
# Drop the alpha dimension
|
||||
left_image = left_image[..., :3]
|
||||
right_image = right_image[..., :3]
|
||||
wrist_image = wrist_image[..., :3]
|
||||
|
||||
# Convert to RGB
|
||||
left_image = left_image[..., ::-1]
|
||||
right_image = right_image[..., ::-1]
|
||||
wrist_image = wrist_image[..., ::-1]
|
||||
|
||||
# In addition to image observations, also capture the proprioceptive state
|
||||
robot_state = obs_dict["robot_state"]
|
||||
cartesian_position = np.array(robot_state["cartesian_position"])
|
||||
joint_position = np.array(robot_state["joint_positions"])
|
||||
gripper_position = np.array([robot_state["gripper_position"]])
|
||||
|
||||
# Save the images to disk so that they can be viewed live while the robot is running
|
||||
# Create one combined image to make live viewing easy
|
||||
if save_to_disk:
|
||||
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
|
||||
combined_image = Image.fromarray(combined_image)
|
||||
combined_image.save("robot_camera_views.png")
|
||||
|
||||
return {
|
||||
"left_image": left_image,
|
||||
"right_image": right_image,
|
||||
"wrist_image": wrist_image,
|
||||
"cartesian_position": cartesian_position,
|
||||
"joint_position": joint_position,
|
||||
"gripper_position": gripper_position,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args: Args = tyro.cli(Args)
|
||||
main(args)
|
||||
137
examples/inference.ipynb
Normal file
137
examples/inference.ipynb
Normal file
@@ -0,0 +1,137 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import dataclasses\n",
|
||||
"\n",
|
||||
"import jax\n",
|
||||
"\n",
|
||||
"from openpi.models import model as _model\n",
|
||||
"from openpi.policies import droid_policy\n",
|
||||
"from openpi.policies import policy_config as _policy_config\n",
|
||||
"from openpi.shared import download\n",
|
||||
"from openpi.training import config as _config\n",
|
||||
"from openpi.training import data_loader as _data_loader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Policy inference\n",
|
||||
"\n",
|
||||
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = _config.get_config(\"pi0_fast_droid\")\n",
|
||||
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
|
||||
"\n",
|
||||
"# Create a trained policy.\n",
|
||||
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
|
||||
"\n",
|
||||
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
|
||||
"example = droid_policy.make_droid_example()\n",
|
||||
"result = policy.infer(example)\n",
|
||||
"\n",
|
||||
"# Delete the policy to free up memory.\n",
|
||||
"del policy\n",
|
||||
"\n",
|
||||
"print(\"Actions shape:\", result[\"actions\"].shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Working with a live model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = _config.get_config(\"pi0_aloha_sim\")\n",
|
||||
"\n",
|
||||
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
|
||||
"key = jax.random.key(0)\n",
|
||||
"\n",
|
||||
"# Create a model from the checkpoint.\n",
|
||||
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
|
||||
"\n",
|
||||
"# We can create fake observations and actions to test the model.\n",
|
||||
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
|
||||
"\n",
|
||||
"# Sample actions from the model.\n",
|
||||
"loss = model.compute_loss(key, obs, act)\n",
|
||||
"print(\"Loss shape:\", loss.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reduce the batch size to reduce memory usage.\n",
|
||||
"config = dataclasses.replace(config, batch_size=2)\n",
|
||||
"\n",
|
||||
"# Load a single batch of data. This is the same data that will be used during training.\n",
|
||||
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
|
||||
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
|
||||
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
|
||||
"obs, act = next(iter(loader))\n",
|
||||
"\n",
|
||||
"# Sample actions from the model.\n",
|
||||
"loss = model.compute_loss(key, obs, act)\n",
|
||||
"\n",
|
||||
"# Delete the model to free up memory.\n",
|
||||
"del model\n",
|
||||
"\n",
|
||||
"print(\"Loss shape:\", loss.shape)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
59
examples/libero/Dockerfile
Normal file
59
examples/libero/Dockerfile
Normal file
@@ -0,0 +1,59 @@
|
||||
# Dockerfile for the LIBERO benchmark.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t libero -f examples/libero/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
|
||||
|
||||
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
make \
|
||||
g++ \
|
||||
clang \
|
||||
libosmesa6-dev \
|
||||
libgl1-mesa-glx \
|
||||
libglew-dev \
|
||||
libglfw3-dev \
|
||||
libgles2-mesa-dev \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxrender1 \
|
||||
libxext6
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
|
||||
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
|
||||
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
|
||||
|
||||
# Create a default config file to avoid an input prompt from LIBERO's init script.
|
||||
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
|
||||
ENV LIBERO_CONFIG_PATH=/tmp/libero
|
||||
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
|
||||
benchmark_root: /app/third_party/libero/libero/libero
|
||||
bddl_files: /app/third_party/libero/libero/libero/bddl_files
|
||||
init_states: /app/third_party/libero/libero/libero/init_files
|
||||
datasets: /app/third_party/libero/libero/datasets
|
||||
assets: /app/third_party/libero/libero/libero/assets
|
||||
EOF
|
||||
|
||||
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"]
|
||||
56
examples/libero/README.md
Normal file
56
examples/libero/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# LIBERO Benchmark
|
||||
|
||||
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
|
||||
|
||||
Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
|
||||
|
||||
This example requires git submodules to be initialized. Don't forget to run:
|
||||
|
||||
```bash
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
# Grant access to the X11 server:
|
||||
sudo xhost +local:docker
|
||||
|
||||
export SERVER_ARGS="--env LIBERO"
|
||||
docker compose -f examples/libero/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.8 examples/libero/.venv
|
||||
source examples/libero/.venv/bin/activate
|
||||
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
|
||||
uv pip install -e packages/openpi-client
|
||||
uv pip install -e third_party/libero
|
||||
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
|
||||
|
||||
# Run the simulation
|
||||
python examples/libero/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env LIBERO
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
If you follow the training instructions and hyperparameters in the `pi0_libero` and `pi0_fast_libero` configs, you should get results similar to the following:
|
||||
|
||||
| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|
||||
|-------|---------------|---------------|-------------|-----------|---------|
|
||||
| π0-FAST @ 30k (finetuned) | 96.4 | 96.8 | 88.6 | 60.2 | 85.5 |
|
||||
| π0 @ 30k (finetuned) | 96.8 | 98.8 | 95.8 | 85.2 | 94.15 |
|
||||
|
||||
Note that the hyperparameters for these runs are not tuned and $\pi_0$-FAST does not use a FAST tokenizer optimized for Libero. Likely, the results could be improved with more tuning, we mainly use these results as an example of how to use openpi to fine-tune $\pi_0$ models on a new dataset.
|
||||
52
examples/libero/compose.yml
Normal file
52
examples/libero/compose.yml
Normal file
@@ -0,0 +1,52 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/libero/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: libero
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/libero/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
- /tmp/.X11-unix:/tmp/.X11-unix:ro
|
||||
environment:
|
||||
- DISPLAY=$DISPLAY
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
106
examples/libero/convert_libero_data_to_lerobot.py
Normal file
106
examples/libero/convert_libero_data_to_lerobot.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Minimal example script for converting a dataset to LeRobot format.
|
||||
|
||||
We use the Libero dataset (stored in RLDS) for this example, but it can be easily
|
||||
modified for any other data you have saved in a custom format.
|
||||
|
||||
Usage:
|
||||
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
|
||||
|
||||
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
|
||||
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
|
||||
|
||||
Note: to run the script, you need to install tensorflow_datasets:
|
||||
`uv pip install tensorflow tensorflow_datasets`
|
||||
|
||||
You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
|
||||
The resulting dataset will get saved to the $LEROBOT_HOME directory.
|
||||
Running this conversion script will take approximately 30 minutes.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
import tensorflow_datasets as tfds
|
||||
import tyro
|
||||
|
||||
REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
|
||||
RAW_DATASET_NAMES = [
|
||||
"libero_10_no_noops",
|
||||
"libero_goal_no_noops",
|
||||
"libero_object_no_noops",
|
||||
"libero_spatial_no_noops",
|
||||
] # For simplicity we will combine multiple Libero datasets into one training dataset
|
||||
|
||||
|
||||
def main(data_dir: str, *, push_to_hub: bool = False):
|
||||
# Clean up any existing dataset in the output directory
|
||||
output_path = LEROBOT_HOME / REPO_NAME
|
||||
if output_path.exists():
|
||||
shutil.rmtree(output_path)
|
||||
|
||||
# Create LeRobot dataset, define features to store
|
||||
# OpenPi assumes that proprio is stored in `state` and actions in `action`
|
||||
# LeRobot assumes that dtype of image data is `image`
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=REPO_NAME,
|
||||
robot_type="panda",
|
||||
fps=10,
|
||||
features={
|
||||
"image": {
|
||||
"dtype": "image",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"wrist_image": {
|
||||
"dtype": "image",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"state": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,),
|
||||
"names": ["state"],
|
||||
},
|
||||
"actions": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": ["actions"],
|
||||
},
|
||||
},
|
||||
image_writer_threads=10,
|
||||
image_writer_processes=5,
|
||||
)
|
||||
|
||||
# Loop over raw Libero datasets and write episodes to the LeRobot dataset
|
||||
# You can modify this for your own data format
|
||||
for raw_dataset_name in RAW_DATASET_NAMES:
|
||||
raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
|
||||
for episode in raw_dataset:
|
||||
for step in episode["steps"].as_numpy_iterator():
|
||||
dataset.add_frame(
|
||||
{
|
||||
"image": step["observation"]["image"],
|
||||
"wrist_image": step["observation"]["wrist_image"],
|
||||
"state": step["observation"]["state"],
|
||||
"actions": step["action"],
|
||||
}
|
||||
)
|
||||
dataset.save_episode(task=step["language_instruction"].decode())
|
||||
|
||||
# Consolidate the dataset, skip computing stats since we will do that later
|
||||
dataset.consolidate(run_compute_stats=False)
|
||||
|
||||
# Optionally push to the Hugging Face Hub
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(
|
||||
tags=["libero", "panda", "rlds"],
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
219
examples/libero/main.py
Normal file
219
examples/libero/main.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import math
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
from libero.libero import benchmark
|
||||
from libero.libero import get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
|
||||
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
#################################################################################################################
|
||||
# Model server parameters
|
||||
#################################################################################################################
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
resize_size: int = 224
|
||||
replan_steps: int = 5
|
||||
|
||||
#################################################################################################################
|
||||
# LIBERO environment-specific parameters
|
||||
#################################################################################################################
|
||||
task_suite_name: str = (
|
||||
"libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
|
||||
)
|
||||
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize i n sim
|
||||
num_trials_per_task: int = 50 # Number of rollouts per task
|
||||
|
||||
#################################################################################################################
|
||||
# Utils
|
||||
#################################################################################################################
|
||||
video_out_path: str = "data/libero/videos" # Path to save videos
|
||||
|
||||
seed: int = 7 # Random Seed (for reproducibility)
|
||||
|
||||
|
||||
def eval_libero(args: Args) -> None:
|
||||
# Set random seed
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# Initialize LIBERO task suite
|
||||
benchmark_dict = benchmark.get_benchmark_dict()
|
||||
task_suite = benchmark_dict[args.task_suite_name]()
|
||||
num_tasks_in_suite = task_suite.n_tasks
|
||||
logging.info(f"Task suite: {args.task_suite_name}")
|
||||
|
||||
pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.task_suite_name == "libero_spatial":
|
||||
max_steps = 220 # longest training demo has 193 steps
|
||||
elif args.task_suite_name == "libero_object":
|
||||
max_steps = 280 # longest training demo has 254 steps
|
||||
elif args.task_suite_name == "libero_goal":
|
||||
max_steps = 300 # longest training demo has 270 steps
|
||||
elif args.task_suite_name == "libero_10":
|
||||
max_steps = 520 # longest training demo has 505 steps
|
||||
elif args.task_suite_name == "libero_90":
|
||||
max_steps = 400 # longest training demo has 373 steps
|
||||
else:
|
||||
raise ValueError(f"Unknown task suite: {args.task_suite_name}")
|
||||
|
||||
client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
|
||||
|
||||
# Start evaluation
|
||||
total_episodes, total_successes = 0, 0
|
||||
for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
|
||||
# Get task
|
||||
task = task_suite.get_task(task_id)
|
||||
|
||||
# Get default LIBERO initial states
|
||||
initial_states = task_suite.get_task_init_states(task_id)
|
||||
|
||||
# Initialize LIBERO environment and task description
|
||||
env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
|
||||
|
||||
# Start episodes
|
||||
task_episodes, task_successes = 0, 0
|
||||
for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
|
||||
logging.info(f"\nTask: {task_description}")
|
||||
|
||||
# Reset environment
|
||||
env.reset()
|
||||
action_plan = collections.deque()
|
||||
|
||||
# Set initial states
|
||||
obs = env.set_init_state(initial_states[episode_idx])
|
||||
|
||||
# Setup
|
||||
t = 0
|
||||
replay_images = []
|
||||
|
||||
logging.info(f"Starting episode {task_episodes+1}...")
|
||||
while t < max_steps + args.num_steps_wait:
|
||||
try:
|
||||
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
|
||||
# and we need to wait for them to fall
|
||||
if t < args.num_steps_wait:
|
||||
obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
|
||||
t += 1
|
||||
continue
|
||||
|
||||
# Get preprocessed image
|
||||
# IMPORTANT: rotate 180 degrees to match train preprocessing
|
||||
img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
|
||||
wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
|
||||
img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
|
||||
)
|
||||
wrist_img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
|
||||
)
|
||||
|
||||
# Save preprocessed image for replay video
|
||||
replay_images.append(img)
|
||||
|
||||
if not action_plan:
|
||||
# Finished executing previous action chunk -- compute new chunk
|
||||
# Prepare observations dict
|
||||
element = {
|
||||
"observation/image": img,
|
||||
"observation/wrist_image": wrist_img,
|
||||
"observation/state": np.concatenate(
|
||||
(
|
||||
obs["robot0_eef_pos"],
|
||||
_quat2axisangle(obs["robot0_eef_quat"]),
|
||||
obs["robot0_gripper_qpos"],
|
||||
)
|
||||
),
|
||||
"prompt": str(task_description),
|
||||
}
|
||||
|
||||
# Query model to get action
|
||||
action_chunk = client.infer(element)["actions"]
|
||||
assert (
|
||||
len(action_chunk) >= args.replan_steps
|
||||
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
|
||||
action_plan.extend(action_chunk[: args.replan_steps])
|
||||
|
||||
action = action_plan.popleft()
|
||||
|
||||
# Execute action in environment
|
||||
obs, reward, done, info = env.step(action.tolist())
|
||||
if done:
|
||||
task_successes += 1
|
||||
total_successes += 1
|
||||
break
|
||||
t += 1
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Caught exception: {e}")
|
||||
break
|
||||
|
||||
task_episodes += 1
|
||||
total_episodes += 1
|
||||
|
||||
# Save a replay video of the episode
|
||||
suffix = "success" if done else "failure"
|
||||
task_segment = task_description.replace(" ", "_")
|
||||
imageio.mimwrite(
|
||||
pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
|
||||
[np.asarray(x) for x in replay_images],
|
||||
fps=10,
|
||||
)
|
||||
|
||||
# Log current results
|
||||
logging.info(f"Success: {done}")
|
||||
logging.info(f"# episodes completed so far: {total_episodes}")
|
||||
logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
|
||||
|
||||
# Log final results
|
||||
logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
|
||||
logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
|
||||
|
||||
logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
|
||||
logging.info(f"Total episodes: {total_episodes}")
|
||||
|
||||
|
||||
def _get_libero_env(task, resolution, seed):
|
||||
"""Initializes and returns the LIBERO environment, along with the task description."""
|
||||
task_description = task.language
|
||||
task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
|
||||
env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
|
||||
env = OffScreenRenderEnv(**env_args)
|
||||
env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
|
||||
return env, task_description
|
||||
|
||||
|
||||
def _quat2axisangle(quat):
|
||||
"""
|
||||
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||
"""
|
||||
# clip quaternion
|
||||
if quat[3] > 1.0:
|
||||
quat[3] = 1.0
|
||||
elif quat[3] < -1.0:
|
||||
quat[3] = -1.0
|
||||
|
||||
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||
if math.isclose(den, 0.0):
|
||||
# This is (close to) a zero degree rotation, immediately return
|
||||
return np.zeros(3)
|
||||
|
||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tyro.cli(eval_libero)
|
||||
11
examples/libero/requirements.in
Normal file
11
examples/libero/requirements.in
Normal file
@@ -0,0 +1,11 @@
|
||||
imageio[ffmpeg]
|
||||
numpy==1.22.4
|
||||
tqdm
|
||||
tyro
|
||||
PyYaml
|
||||
opencv-python==4.6.0.66
|
||||
torch==1.11.0+cu113
|
||||
torchvision==0.12.0+cu113
|
||||
torchaudio==0.11.0+cu113
|
||||
robosuite==1.4.1
|
||||
matplotlib==3.5.3
|
||||
136
examples/libero/requirements.txt
Normal file
136
examples/libero/requirements.txt
Normal file
@@ -0,0 +1,136 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
|
||||
absl-py==2.1.0
|
||||
# via mujoco
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
eval-type-backport==0.2.0
|
||||
# via tyro
|
||||
evdev==1.7.1
|
||||
# via pynput
|
||||
fonttools==4.55.3
|
||||
# via matplotlib
|
||||
glfw==1.12.0
|
||||
# via mujoco
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.35.1
|
||||
# via -r examples/libero/requirements.in
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
importlib-metadata==8.5.0
|
||||
# via typeguard
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
llvmlite==0.36.0
|
||||
# via numba
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.5.3
|
||||
# via -r examples/libero/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mujoco==3.2.3
|
||||
# via robosuite
|
||||
numba==0.53.1
|
||||
# via robosuite
|
||||
numpy==1.22.4
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# imageio
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# robosuite
|
||||
# scipy
|
||||
# torchvision
|
||||
opencv-python==4.6.0.66
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# robosuite
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
# robosuite
|
||||
# torchvision
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pynput==1.7.7
|
||||
# via robosuite
|
||||
pyopengl==3.1.7
|
||||
# via mujoco
|
||||
pyparsing==3.1.4
|
||||
# via matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pyyaml==6.0.2
|
||||
# via -r examples/libero/requirements.in
|
||||
requests==2.32.3
|
||||
# via torchvision
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
robosuite==1.4.1
|
||||
# via -r examples/libero/requirements.in
|
||||
scipy==1.10.1
|
||||
# via robosuite
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# imageio-ffmpeg
|
||||
# numba
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
termcolor==2.4.0
|
||||
# via robosuite
|
||||
torch==1.11.0+cu113
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# torchaudio
|
||||
# torchvision
|
||||
torchaudio==0.11.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
torchvision==0.12.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
tqdm==4.67.1
|
||||
# via -r examples/libero/requirements.in
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# torch
|
||||
# torchvision
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/libero/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
zipp==3.20.2
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
# importlib-resources
|
||||
134
examples/policy_records.ipynb
Normal file
134
examples/policy_records.ipynb
Normal file
@@ -0,0 +1,134 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pathlib\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"record_path = pathlib.Path(\"../policy_records\")\n",
|
||||
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
|
||||
"\n",
|
||||
"records = []\n",
|
||||
"for i in range(num_steps):\n",
|
||||
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
|
||||
" records.append(record)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"length of records\", len(records))\n",
|
||||
"print(\"keys in records\", records[0].keys())\n",
|
||||
"\n",
|
||||
"for k in records[0]:\n",
|
||||
" print(f\"{k} shape: {records[0][k].shape}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_image(step: int, idx: int = 0):\n",
|
||||
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
|
||||
" return img[idx].transpose(1, 2, 0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def show_image(step: int, idx_lst: list[int]):\n",
|
||||
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
|
||||
" return Image.fromarray(np.hstack(imgs))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for i in range(2):\n",
|
||||
" display(show_image(i, [0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_axis(name, axis):\n",
|
||||
" return np.array([record[name][axis] for record in records])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# qpos is [..., 14] of type float:\n",
|
||||
"# 0-5: left arm joint angles\n",
|
||||
"# 6: left arm gripper\n",
|
||||
"# 7-12: right arm joint angles\n",
|
||||
"# 13: right arm gripper\n",
|
||||
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_data():\n",
|
||||
" cur_dim = 0\n",
|
||||
" in_data = {}\n",
|
||||
" out_data = {}\n",
|
||||
" for name, dim_size in names:\n",
|
||||
" for i in range(dim_size):\n",
|
||||
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
|
||||
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
|
||||
" cur_dim += 1\n",
|
||||
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"in_data, out_data = make_data()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for name in in_data.columns:\n",
|
||||
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
|
||||
" data.plot()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
32
examples/simple_client/Dockerfile
Normal file
32
examples/simple_client/Dockerfile
Normal file
@@ -0,0 +1,32 @@
|
||||
# Dockerfile for the simple client.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t simple_client -f examples/simple_client/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
|
||||
|
||||
FROM python:3.7-slim
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
||||
|
||||
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
|
||||
30
examples/simple_client/README.md
Normal file
30
examples/simple_client/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Simple Client
|
||||
|
||||
A minimal client that sends observations to the server and prints the inference rate.
|
||||
|
||||
You can specifiy which runtime environment to use using the `--env` flag. You can see the available options by running:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --help
|
||||
```
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA_SIM"
|
||||
docker compose -f examples/simple_client/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --env DROID
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env DROID
|
||||
```
|
||||
42
examples/simple_client/compose.yml
Normal file
42
examples/simple_client/compose.yml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/simple_client/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: simple_client
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/simple_client/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
89
examples/simple_client/main.py
Normal file
89
examples/simple_client/main.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tyro
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
LIBERO = "libero"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
env: EnvMode = EnvMode.ALOHA_SIM
|
||||
num_steps: int = 10
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
obs_fn = {
|
||||
EnvMode.ALOHA: _random_observation_aloha,
|
||||
EnvMode.ALOHA_SIM: _random_observation_aloha,
|
||||
EnvMode.DROID: _random_observation_droid,
|
||||
EnvMode.LIBERO: _random_observation_libero,
|
||||
}[args.env]
|
||||
|
||||
policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
logging.info(f"Server metadata: {policy.get_server_metadata()}")
|
||||
|
||||
# Send 1 observation to make sure the model is loaded.
|
||||
policy.infer(obs_fn())
|
||||
|
||||
start = time.time()
|
||||
for _ in range(args.num_steps):
|
||||
policy.infer(obs_fn())
|
||||
end = time.time()
|
||||
|
||||
print(f"Total time taken: {end - start:.2f} s")
|
||||
print(f"Average inference time: {1000 * (end - start) / args.num_steps:.2f} ms")
|
||||
|
||||
|
||||
def _random_observation_aloha() -> dict:
|
||||
return {
|
||||
"state": np.ones((14,)),
|
||||
"images": {
|
||||
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
},
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_droid() -> dict:
|
||||
return {
|
||||
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/joint_position": np.random.rand(7),
|
||||
"observation/gripper_position": np.random.rand(1),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_libero() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(8),
|
||||
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main(tyro.cli(Args))
|
||||
2
examples/simple_client/requirements.in
Normal file
2
examples/simple_client/requirements.in
Normal file
@@ -0,0 +1,2 @@
|
||||
numpy
|
||||
tyro
|
||||
27
examples/simple_client/requirements.txt
Normal file
27
examples/simple_client/requirements.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
|
||||
backports-cached-property==1.0.2
|
||||
# via tyro
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
eval-type-backport==0.1.3
|
||||
# via tyro
|
||||
markdown-it-py==2.2.0
|
||||
# via rich
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
numpy==1.21.6
|
||||
# via -r examples/simple_client/requirements.in
|
||||
pygments==2.17.2
|
||||
# via rich
|
||||
rich==13.8.1
|
||||
# via tyro
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
typing-extensions==4.7.1
|
||||
# via
|
||||
# markdown-it-py
|
||||
# rich
|
||||
# tyro
|
||||
tyro==0.9.1
|
||||
# via -r examples/simple_client/requirements.in
|
||||
Reference in New Issue
Block a user