multi-node openpi commit
70
policy/openpi-InternData-A1/examples/aloha_real/Dockerfile
Normal file
@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.

# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash

FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
SHELL ["/bin/bash", "-c"]

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    cmake \
    curl \
    libffi-dev \
    python3-rosdep \
    python3-rosinstall \
    python3-rosinstall-generator \
    whiptail \
    git \
    wget \
    openssh-client \
    ros-noetic-cv-bridge \
    ros-noetic-usb-cam \
    ros-noetic-realsense2-camera \
    keyboard-configuration

WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n

COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make

# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
    cd /python && \
    wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
    tar -zxvf Python-3.10.14.tgz && \
    cd Python-3.10.14 && \
    ls -lhR && \
    ./configure --enable-optimizations && \
    make install && \
    echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    cd ~ && rm -rf /python && \
    rm -rf /var/lib/apt/lists/*

COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml

ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app

# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh

ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]
126
policy/openpi-InternData-A1/examples/aloha_real/README.md
Normal file
@@ -0,0 +1,126 @@
# Run Aloha (Real Robot)

This example demonstrates how to run a policy on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.

## Prerequisites

This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.

1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use the serial numbers of your cameras.

## With Docker

```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client

# Run the robot
python -m examples.aloha_real.main
```

Terminal window 2:

```bash
roslaunch aloha ros_nodes.launch
```

Terminal window 3:

```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```

## **ALOHA Checkpoint Guide**

The `pi0_base` model can be used zero-shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, "fold the towel" and "open the tupperware and put the food on the plate," which can perform more advanced tasks on the ALOHA.

While we've found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we've seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.

---

### **Toast Task**

This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.

- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
  - Works on both real toast and rubber fake toast
  - Compatible with standard 2-slice toasters
  - Works with plates of varying colors

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />

- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).

### **Towel Task**

This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.

- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
  - Works on towels of varying solid colors
  - Performance is worse on heavily textured or striped towels

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />

- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.

### **Tupperware Task**

This task involves opening a tupperware filled with food and pouring the contents onto a plate.

- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
  - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
  - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
  - The policy has seen plates of varying solid colors.

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />

- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
  - Tupperware should be on the left.
  - Plate should be on the right or bottom.
- The tupperware flap should point toward the plate.

## Training on your own Aloha dataset

1. Convert the dataset to the LeRobot dataset v2.0 format.

   We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).

2. Define a training config that uses the custom dataset.

   We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.

IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the `trossen` asset_id and a path to the pretrained checkpoint's asset directory within the `AssetsConfig`.
66
policy/openpi-InternData-A1/examples/aloha_real/compose.yml
Normal file
@@ -0,0 +1,66 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
  runtime:
    image: aloha_real
    depends_on:
      - aloha_ros_nodes
      - ros_master
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  aloha_ros_nodes:
    image: aloha_real
    depends_on:
      - ros_master
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - /dev:/dev
    command: roslaunch --wait aloha ros_nodes.launch

  ros_master:
    image: ros:noetic-robot
    network_mode: host
    privileged: true
    command:
      - roscore

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
71
policy/openpi-InternData-A1/examples/aloha_real/constants.py
Normal file
@@ -0,0 +1,71 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa

### Task parameters

### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]

# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844

# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213

############################ Helper functions ############################

MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
    MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
    PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))

MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
    MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
    PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))

MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)

MASTER_POS2JOINT = (
    lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    + MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
    lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    + PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)

MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
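The normalize/unnormalize helpers in this file are plain linear maps between a gripper's raw joint range and [0, 1]. A minimal standalone sketch of the round-trip, with the puppet joint limits inlined from the constants above:

```python
# Linear normalization helpers, mirroring the lambdas above (constants inlined from this file).
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213


def normalize(x: float) -> float:
    # raw joint angle -> [0, 1]; 0 = fully closed, 1 = fully open
    return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)


def unnormalize(x: float) -> float:
    # [0, 1] -> raw joint angle
    return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE


assert abs(normalize(PUPPET_GRIPPER_JOINT_OPEN) - 1.0) < 1e-9
assert abs(normalize(PUPPET_GRIPPER_JOINT_CLOSE)) < 1e-9
assert abs(unnormalize(normalize(0.5)) - 0.5) < 1e-9  # round-trip is the identity
```

The MASTER2PUPPET functions are just a normalize in master range followed by an unnormalize into puppet range.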
272
policy/openpi-InternData-A1/examples/aloha_real/convert_aloha_data_to_lerobot.py
Normal file
@@ -0,0 +1,272 @@
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.

Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""

import dataclasses
from pathlib import Path
import shutil
from typing import Literal

import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro


@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    use_videos: bool = True
    tolerance_s: float = 0.0001
    image_writer_processes: int = 10
    image_writer_threads: int = 5
    video_backend: str | None = None


DEFAULT_DATASET_CONFIG = DatasetConfig()


def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
    motors = [
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
    ]
    cameras = [
        "cam_high",
        "cam_low",
        "cam_left_wrist",
        "cam_right_wrist",
    ]

    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
        "action": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
    }

    if has_velocity:
        features["observation.velocity"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    if has_effort:
        features["observation.effort"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    if Path(LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=50,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )


def get_cameras(hdf5_files: list[Path]) -> list[str]:
    with h5py.File(hdf5_files[0], "r") as ep:
        # ignore depth channel, not currently handled
        return [key for key in ep["/observations/images"].keys() if "depth" not in key]  # noqa: SIM118


def has_velocity(hdf5_files: list[Path]) -> bool:
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/qvel" in ep


def has_effort(hdf5_files: list[Path]) -> bool:
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/effort" in ep


def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    imgs_per_cam = {}
    for camera in cameras:
        uncompressed = ep[f"/observations/images/{camera}"].ndim == 4

        if uncompressed:
            # load all images in RAM
            imgs_array = ep[f"/observations/images/{camera}"][:]
        else:
            import cv2

            # load one compressed image after the other in RAM and uncompress
            imgs_array = []
            for data in ep[f"/observations/images/{camera}"]:
                imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
            imgs_array = np.array(imgs_array)

        imgs_per_cam[camera] = imgs_array
    return imgs_per_cam


def load_raw_episode_data(
    ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        velocity = None
        if "/observations/qvel" in ep:
            velocity = torch.from_numpy(ep["/observations/qvel"][:])

        effort = None
        if "/observations/effort" in ep:
            effort = torch.from_numpy(ep["/observations/effort"][:])

        imgs_per_cam = load_raw_images_per_camera(
            ep,
            [
                "cam_high",
                "cam_low",
                "cam_left_wrist",
                "cam_right_wrist",
            ],
        )

    return imgs_per_cam, state, action, velocity, effort


def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    if episodes is None:
        episodes = range(len(hdf5_files))

    for ep_idx in tqdm.tqdm(episodes):
        ep_path = hdf5_files[ep_idx]

        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
        num_frames = state.shape[0]

        for i in range(num_frames):
            frame = {
                "observation.state": state[i],
                "action": action[i],
            }

            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]

            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]

            dataset.add_frame(frame)

        dataset.save_episode(task=task)

    return dataset


def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = True,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        download_raw(raw_dir, repo_id=raw_repo_id)

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))

    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    tyro.cli(port_aloha)
57
policy/openpi-InternData-A1/examples/aloha_real/env.py
Normal file
@@ -0,0 +1,57 @@
from typing import List, Optional  # noqa: UP035

import einops
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override

from examples.aloha_real import real_env as _real_env


class AlohaRealEnvironment(_environment.Environment):
    """An environment for an Aloha robot on real hardware."""

    def __init__(
        self,
        reset_position: Optional[List[float]] = None,  # noqa: UP006,UP007
        render_height: int = 224,
        render_width: int = 224,
    ) -> None:
        self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
        self._render_height = render_height
        self._render_width = render_width

        self._ts = None

    @override
    def reset(self) -> None:
        self._ts = self._env.reset()

    @override
    def is_episode_complete(self) -> bool:
        return False

    @override
    def get_observation(self) -> dict:
        if self._ts is None:
            raise RuntimeError("Timestep is not set. Call reset() first.")

        obs = self._ts.observation
        for k in list(obs["images"].keys()):
            if "_depth" in k:
                del obs["images"][k]

        for cam_name in obs["images"]:
            img = image_tools.convert_to_uint8(
                image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
            )
            obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")

        return {
            "state": obs["qpos"],
            "images": obs["images"],
        }

    @override
    def apply_action(self, action: dict) -> None:
        self._ts = self._env.step(action["actions"])
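`get_observation` above pads-and-resizes each camera frame and converts it from HWC (height, width, channels) to CHW layout before it is sent to the policy. A numpy-only sketch of the layout conversion, where a plain transpose stands in for the `einops.rearrange` call and the all-zeros frame is purely illustrative:

```python
import numpy as np


def hwc_to_chw(img: np.ndarray) -> np.ndarray:
    # Equivalent of einops.rearrange(img, "h w c -> c h w") for a single frame.
    return np.transpose(img, (2, 0, 1))


# A fake 480x640 RGB camera frame, the native ALOHA camera resolution.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
chw = hwc_to_chw(frame)
assert chw.shape == (3, 480, 640)
```

The channels-first layout matches the `(3, 480, 640)` image feature shape declared in the dataset conversion script.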
51
policy/openpi-InternData-A1/examples/aloha_real/main.py
Normal file
@@ -0,0 +1,51 @@
import dataclasses
import logging

from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro

from examples.aloha_real import env as _env


@dataclasses.dataclass
class Args:
    host: str = "0.0.0.0"
    port: int = 8000

    action_horizon: int = 25

    num_episodes: int = 1
    max_episode_steps: int = 1000


def main(args: Args) -> None:
    ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    metadata = ws_client_policy.get_server_metadata()
    logging.info(f"Server metadata: {metadata}")

    runtime = _runtime.Runtime(
        environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=ws_client_policy,
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[],
        max_hz=50,
        num_episodes=args.num_episodes,
        max_episode_steps=args.max_episode_steps,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)
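The `ActionChunkBroker` wrapper used above queries the policy once every `action_horizon` steps and replays the returned chunk one action at a time, amortizing one network round-trip over many control steps. A minimal sketch of that behavior (the `ChunkBrokerSketch` class and `fake_infer` stub below are hypothetical, not the actual openpi-client implementation):

```python
from typing import Callable, Dict, List


class ChunkBrokerSketch:
    """Caches a chunk of actions and re-queries the policy only when the chunk is exhausted."""

    def __init__(self, infer: Callable[[Dict], Dict], action_horizon: int) -> None:
        self._infer = infer
        self._horizon = action_horizon
        self._chunk: List = []

    def step(self, obs: Dict):
        if not self._chunk:
            # One (expensive) policy query yields up to `action_horizon` actions.
            self._chunk = list(self._infer(obs)["actions"][: self._horizon])
        return self._chunk.pop(0)


calls = []


def fake_infer(obs: Dict) -> Dict:
    # Stand-in for the websocket policy: returns a 25-step action chunk.
    calls.append(obs)
    return {"actions": [[i] for i in range(25)]}


broker = ChunkBrokerSketch(fake_infer, action_horizon=25)
for _ in range(50):
    broker.step({"state": None})
assert len(calls) == 2  # 50 steps with horizon 25 -> only two policy queries
```

At 50 Hz control with `action_horizon=25`, the policy server is queried roughly twice per second rather than on every step.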
176
policy/openpi-InternData-A1/examples/aloha_real/real_env.py
Normal file
@@ -0,0 +1,176 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time
from typing import Optional, List
import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np

from examples.aloha_real import constants
from examples.aloha_real import robot_utils

# This is the reset position that is used by the standard Aloha runtime.
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]


class RealEnv:
    """
    Environment for real robot bi-manual manipulation.

    Action space: [left_arm_qpos (6),            # absolute joint position
                   left_gripper_positions (1),   # normalized gripper position (0: close, 1: open)
                   right_arm_qpos (6),           # absolute joint position
                   right_gripper_positions (1)]  # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[left_arm_qpos (6),          # absolute joint position
                                       left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                       right_arm_qpos (6),         # absolute joint position
                                       right_gripper_qpos (1)],    # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[left_arm_qvel (6),          # absolute joint velocity (rad)
                                       left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                       right_arm_qvel (6),         # absolute joint velocity (rad)
                                       right_gripper_qvel (1)],    # normalized gripper velocity (pos: opening, neg: closing)
                        "images": {"cam_high": (480x640x3),        # h, w, c, dtype='uint8'
                                   "cam_low": (480x640x3),         # h, w, c, dtype='uint8'
                                   "cam_left_wrist": (480x640x3),  # h, w, c, dtype='uint8'
                                   "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
    """
    def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
        # reset_position = START_ARM_POSE[:6]
        self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION

        self.puppet_bot_left = InterbotixManipulatorXS(
            robot_model="vx300s",
            group_name="arm",
            gripper_name="gripper",
            robot_name="puppet_left",
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
        )
        if setup_robots:
            self.setup_robots()

        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        robot_utils.setup_puppet_bot(self.puppet_bot_left)
        robot_utils.setup_puppet_bot(self.puppet_bot_right)

    def get_qpos(self):
        left_qpos_raw = self.recorder_left.qpos
        right_qpos_raw = self.recorder_right.qpos
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
        ]  # this is position not joint
        right_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
        ]  # this is position not joint
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    def get_qvel(self):
        left_qvel_raw = self.recorder_left.qvel
        right_qvel_raw = self.recorder_right.qvel
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    def get_effort(self):
        left_effort_raw = self.recorder_left.effort
        right_effort_raw = self.recorder_right.effort
        left_robot_effort = left_effort_raw[:7]
        right_robot_effort = right_effort_raw[:7]
        return np.concatenate([left_robot_effort, right_robot_effort])

    def get_images(self):
        return self.image_recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
        self.gripper_command.cmd = left_gripper_desired_joint
        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)

        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
            right_gripper_desired_pos_normalized
        )
        self.gripper_command.cmd = right_gripper_desired_joint
        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        robot_utils.move_arms(
            [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
        )

    def _reset_gripper(self):
        """Set to position mode and do position resets: first close then open. Then change back to PWM mode.

        NOTE: This diverges from the original Aloha code, which first opens then closes the gripper. Pi internal aloha data
        was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
        increase the frequency of motor faults.
        """
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
|
||||
)
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
|
||||
)
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
return obs
|
||||
|
||||
def get_reward(self):
|
||||
return 0
|
||||
|
||||
def reset(self, *, fake=False):
|
||||
if not fake:
|
||||
# Reboot puppet robot gripper motors
|
||||
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self._reset_joints()
|
||||
self._reset_gripper()
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
state_len = int(len(action) / 2)
|
||||
left_action = action[:state_len]
|
||||
right_action = action[state_len:]
|
||||
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
||||
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
||||
self.set_gripper_pose(left_action[-1], right_action[-1])
|
||||
time.sleep(constants.DT)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
||||
# Arm actions
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# Gripper actions
|
||||
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
||||
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
|
||||
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
|
||||
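The action layout assumed by `step()` and produced by `get_action()` above is a 14-dim vector: six arm joints plus one normalized gripper value per arm, left half first. A minimal sketch of that split (the `action` values here are hypothetical placeholders, not real commands):

```python
import numpy as np

# Hypothetical 14-dim action, in the layout used by step()/get_action():
# [left joints (6), left gripper, right joints (6), right gripper]
action = np.arange(14.0)

state_len = len(action) // 2  # 7 values per arm
left_action = action[:state_len]
right_action = action[state_len:]

# Per arm: first 6 entries drive the joints, the last entry drives the gripper.
left_joints, left_gripper = left_action[:6], left_action[-1]
right_joints, right_gripper = right_action[:6], right_action[-1]
```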
@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy>=1.22.4,<2.0.0
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets
156
policy/openpi-InternData-A1/examples/aloha_real/requirements.txt
Normal file
@@ -0,0 +1,156 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
    # via
    #   dm-control
    #   dm-env
    #   labmaze
    #   mujoco
catkin-pkg==1.0.0
    # via rospkg
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
contourpy==1.1.1
    # via matplotlib
cycler==0.12.1
    # via matplotlib
distro==1.9.0
    # via rospkg
dm-control==1.0.23
    # via -r examples/aloha_real/requirements.in
dm-env==1.6
    # via dm-control
dm-tree==0.1.8
    # via
    #   dm-control
    #   dm-env
docstring-parser==0.16
    # via tyro
docutils==0.20.1
    # via catkin-pkg
einops==0.8.0
    # via -r examples/aloha_real/requirements.in
etils==1.3.0
    # via mujoco
fonttools==4.55.2
    # via matplotlib
glfw==2.8.0
    # via
    #   dm-control
    #   mujoco
h5py==3.11.0
    # via -r examples/aloha_real/requirements.in
idna==3.10
    # via requests
importlib-resources==6.4.5
    # via etils
kiwisolver==1.4.7
    # via matplotlib
labmaze==1.0.6
    # via dm-control
lxml==5.3.0
    # via dm-control
markdown-it-py==3.0.0
    # via rich
matplotlib==3.7.5
    # via -r examples/aloha_real/requirements.in
mdurl==0.1.2
    # via markdown-it-py
modern-robotics==1.1.1
    # via -r examples/aloha_real/requirements.in
msgpack==1.1.0
    # via -r examples/aloha_real/requirements.in
mujoco==3.2.3
    # via dm-control
numpy==1.24.4
    # via
    #   -r examples/aloha_real/requirements.in
    #   contourpy
    #   dm-control
    #   dm-env
    #   h5py
    #   labmaze
    #   matplotlib
    #   modern-robotics
    #   mujoco
    #   opencv-python
    #   pyquaternion
    #   scipy
opencv-python==4.10.0.84
    # via -r examples/aloha_real/requirements.in
packaging==24.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
pexpect==4.9.0
    # via -r examples/aloha_real/requirements.in
pillow==10.4.0
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
protobuf==5.29.1
    # via dm-control
ptyprocess==0.7.0
    # via pexpect
pygments==2.18.0
    # via rich
pyopengl==3.1.7
    # via
    #   dm-control
    #   mujoco
pyparsing==3.1.4
    # via
    #   catkin-pkg
    #   dm-control
    #   matplotlib
pyquaternion==0.9.9
    # via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
    # via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
    # via
    #   catkin-pkg
    #   matplotlib
pyyaml==6.0.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   rospkg
requests==2.32.3
    # via
    #   -r examples/aloha_real/requirements.in
    #   dm-control
rich==13.9.4
    # via tyro
rospkg==1.5.1
    # via -r examples/aloha_real/requirements.in
scipy==1.10.1
    # via dm-control
setuptools==75.3.0
    # via
    #   catkin-pkg
    #   dm-control
    #   labmaze
shtab==1.7.1
    # via tyro
six==1.17.0
    # via python-dateutil
tqdm==4.67.1
    # via dm-control
typeguard==4.4.0
    # via tyro
typing-extensions==4.12.2
    # via
    #   etils
    #   rich
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/aloha_real/requirements.in
urllib3==2.2.3
    # via requests
websockets==14.1
    # via -r examples/aloha_real/requirements.in
zipp==3.20.2
    # via etils
275
policy/openpi-InternData-A1/examples/aloha_real/robot_utils.py
Normal file
@@ -0,0 +1,275 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time

from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState

from examples.aloha_real import constants


class ImageRecorder:
    def __init__(self, init_node=True, is_debug=False):
        self.is_debug = is_debug
        self.bridge = CvBridge()
        self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]

        if init_node:
            rospy.init_node("image_recorder", anonymous=True)
        for cam_name in self.camera_names:
            setattr(self, f"{cam_name}_rgb_image", None)
            setattr(self, f"{cam_name}_depth_image", None)
            setattr(self, f"{cam_name}_timestamp", 0.0)
            if cam_name == "cam_high":
                callback_func = self.image_cb_cam_high
            elif cam_name == "cam_low":
                callback_func = self.image_cb_cam_low
            elif cam_name == "cam_left_wrist":
                callback_func = self.image_cb_cam_left_wrist
            elif cam_name == "cam_right_wrist":
                callback_func = self.image_cb_cam_right_wrist
            else:
                raise NotImplementedError
            rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
            if self.is_debug:
                setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))

        self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
        time.sleep(0.5)

    def image_cb(self, cam_name, data):
        setattr(
            self,
            f"{cam_name}_rgb_image",
            self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
        )
        # setattr(
        #     self,
        #     f"{cam_name}_depth_image",
        #     self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
        # )
        setattr(
            self,
            f"{cam_name}_timestamp",
            data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
        )
        # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
        # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
        # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
        if self.is_debug:
            getattr(self, f"{cam_name}_timestamps").append(
                data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
            )

    def image_cb_cam_high(self, data):
        cam_name = "cam_high"
        return self.image_cb(cam_name, data)

    def image_cb_cam_low(self, data):
        cam_name = "cam_low"
        return self.image_cb(cam_name, data)

    def image_cb_cam_left_wrist(self, data):
        cam_name = "cam_left_wrist"
        return self.image_cb(cam_name, data)

    def image_cb_cam_right_wrist(self, data):
        cam_name = "cam_right_wrist"
        return self.image_cb(cam_name, data)

    def get_images(self):
        image_dict = {}
        for cam_name in self.camera_names:
            while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
                time.sleep(0.00001)
            rgb_image = getattr(self, f"{cam_name}_rgb_image")
            depth_image = getattr(self, f"{cam_name}_depth_image")
            self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
            image_dict[cam_name] = rgb_image
            image_dict[f"{cam_name}_depth"] = depth_image
        return image_dict

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        for cam_name in self.camera_names:
            image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
            print(f"{cam_name} {image_freq=:.2f}")
        print()


class Recorder:
    def __init__(self, side, init_node=True, is_debug=False):
        self.secs = None
        self.nsecs = None
        self.qpos = None
        self.effort = None
        self.arm_command = None
        self.gripper_command = None
        self.is_debug = is_debug

        if init_node:
            rospy.init_node("recorder", anonymous=True)
        rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_group",
            JointGroupCommand,
            self.puppet_arm_commands_cb,
        )
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_single",
            JointSingleCommand,
            self.puppet_gripper_commands_cb,
        )
        if self.is_debug:
            self.joint_timestamps = deque(maxlen=50)
            self.arm_command_timestamps = deque(maxlen=50)
            self.gripper_command_timestamps = deque(maxlen=50)
        time.sleep(0.1)

    def puppet_state_cb(self, data):
        self.qpos = data.position
        self.qvel = data.velocity
        self.effort = data.effort
        self.data = data
        if self.is_debug:
            self.joint_timestamps.append(time.time())

    def puppet_arm_commands_cb(self, data):
        self.arm_command = data.cmd
        if self.is_debug:
            self.arm_command_timestamps.append(time.time())

    def puppet_gripper_commands_cb(self, data):
        self.gripper_command = data.cmd
        if self.is_debug:
            self.gripper_command_timestamps.append(time.time())

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        joint_freq = 1 / dt_helper(self.joint_timestamps)
        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)

        print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")


def get_arm_joint_positions(bot):
    return bot.arm.core.joint_states.position[:6]


def get_arm_gripper_positions(bot):
    return bot.gripper.core.joint_states.position[6]


def move_arms(bot_list, target_pose_list, move_time=1):
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]
    for t in range(num_steps):
        for bot_id, bot in enumerate(bot_list):
            bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
        time.sleep(constants.DT)


def move_grippers(bot_list, target_pose_list, move_time):
    print(f"Moving grippers to {target_pose_list=}")
    gripper_command = JointSingleCommand(name="gripper")
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]

    with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
        for t in range(num_steps):
            d = {}
            for bot_id, bot in enumerate(bot_list):
                gripper_command.cmd = traj_list[bot_id][t]
                bot.gripper.core.pub_single.publish(gripper_command)
                d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
            f.write(json.dumps(d) + "\n")
            time.sleep(constants.DT)


def setup_puppet_bot(bot):
    bot.dxl.robot_reboot_motors("single", "gripper", True)
    bot.dxl.robot_set_operating_modes("group", "arm", "position")
    bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_on(bot)


def setup_master_bot(bot):
    bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
    bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_off(bot)


def set_standard_pid_gains(bot):
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)


def set_low_pid_gains(bot):
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)


def torque_off(bot):
    bot.dxl.robot_torque_enable("group", "arm", False)
    bot.dxl.robot_torque_enable("single", "gripper", False)


def torque_on(bot):
    bot.dxl.robot_torque_enable("group", "arm", True)
    bot.dxl.robot_torque_enable("single", "gripper", True)


# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
    print("\nSyncing!")

    # activate master arms
    torque_on(master_bot_left)
    torque_on(master_bot_right)

    # get puppet arm positions
    puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
    puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)

    # get puppet gripper positions
    puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
    puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)

    # move master arms to puppet positions
    move_arms(
        [master_bot_left, master_bot_right],
        [puppet_left_qpos, puppet_right_qpos],
        move_time=1,
    )

    # move master grippers to puppet positions
    move_grippers(
        [master_bot_left, master_bot_right],
        [puppet_left_gripper, puppet_right_gripper],
        move_time=1,
    )
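Both `move_arms` and `move_grippers` generate their waypoints the same way: linearly interpolate from the current pose to the target, one waypoint per `constants.DT` control tick. A minimal standalone sketch of that interpolation (here `DT = 0.02` is an assumed control period; the real value comes from `constants.DT`):

```python
import numpy as np

DT = 0.02  # assumed control period; the real code reads constants.DT

def make_joint_trajectory(curr_pose, target_pose, move_time):
    """Linearly interpolate joint positions, one waypoint row per control tick."""
    num_steps = int(move_time / DT)
    # np.linspace broadcasts over the joint dimension: shape (num_steps, n_joints).
    return np.linspace(curr_pose, target_pose, num_steps)

# Move 6 joints from all-zeros to all-ones over one second (50 ticks at 0.02 s).
traj = make_joint_trajectory(np.zeros(6), np.ones(6), move_time=1.0)
```

The first row equals the current pose and the last row the target, so commanding one row per tick produces a smooth constant-velocity sweep.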
@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoDisplay(_subscriber.Subscriber):
    """Displays video frames."""

    def __init__(self) -> None:
        self._ax: plt.Axes | None = None
        self._plt_img: plt.Image | None = None

    @override
    def on_episode_start(self) -> None:
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        assert self._ax is not None

        im = observation["image"][0]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]

        if self._plt_img is None:
            self._plt_img = self._ax.imshow(im)
        else:
            self._plt_img.set_data(im)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        plt.ioff()
        plt.close()
41
policy/openpi-InternData-A1/examples/aloha_sim/Dockerfile
Normal file
@@ -0,0 +1,41 @@
# Dockerfile for the Aloha simulation environment.

# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash

FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

RUN apt-get update && \
    apt-get install -y \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev
ENV MUJOCO_GL=egl

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
36
policy/openpi-InternData-A1/examples/aloha_sim/README.md
Normal file
@@ -0,0 +1,36 @@
# Run Aloha Sim

## With Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client

# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```

Note: If you are seeing EGL errors, you may need to install the following dependencies:

```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```
42
policy/openpi-InternData-A1/examples/aloha_sim/compose.yml
Normal file
@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
  runtime:
    image: aloha_sim
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_sim/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
56
policy/openpi-InternData-A1/examples/aloha_sim/env.py
Normal file
@@ -0,0 +1,56 @@
import gym_aloha  # noqa: F401
import gymnasium
import numpy as np
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override


class AlohaSimEnvironment(_environment.Environment):
    """An environment for an Aloha robot in simulation."""

    def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        self._gym = gymnasium.make(task, obs_type=obs_type)

        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def is_episode_complete(self) -> bool:
        return self._done

    @override
    def get_observation(self) -> dict:
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = terminated or truncated
        self._episode_reward = max(self._episode_reward, reward)

    def _convert_observation(self, gym_obs: dict) -> dict:
        img = gym_obs["pixels"]["top"]
        img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
        # Convert axis order from [H, W, C] --> [C, H, W]
        img = np.transpose(img, (2, 0, 1))

        return {
            "state": gym_obs["agent_pos"],
            "images": {"cam_high": img},
        }
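`_convert_observation` resizes the top camera frame to 224x224 and reorders it from [H, W, C] to [C, H, W]; the subscribers above invert that transpose before display or saving. The axis bookkeeping in isolation (the frame here is a hypothetical blank image):

```python
import numpy as np

# Hypothetical uint8 camera frame in [H, W, C] order, as the gym env produces.
img_hwc = np.zeros((224, 224, 3), dtype=np.uint8)

# Same reorder as _convert_observation: [H, W, C] -> [C, H, W].
img_chw = np.transpose(img_hwc, (2, 0, 1))

# Inverse reorder, as used by VideoDisplay/VideoSaver: [C, H, W] -> [H, W, C].
img_back = np.transpose(img_chw, (1, 2, 0))
```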
55
policy/openpi-InternData-A1/examples/aloha_sim/main.py
Normal file
@@ -0,0 +1,55 @@
import dataclasses
import logging
import pathlib

import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro


@dataclasses.dataclass
class Args:
    out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")

    task: str = "gym_aloha/AlohaTransferCube-v0"
    seed: int = 0

    action_horizon: int = 10

    host: str = "0.0.0.0"
    port: int = 8000

    display: bool = False


def main(args: Args) -> None:
    runtime = _runtime.Runtime(
        environment=_env.AlohaSimEnvironment(
            task=args.task,
            seed=args.seed,
        ),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=_websocket_client_policy.WebsocketClientPolicy(
                    host=args.host,
                    port=args.port,
                ),
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[
            _saver.VideoSaver(args.out_dir),
        ],
        max_hz=50,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)
@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy>=1.22.4,<2.0.0
typing-extensions
tyro
websockets
132
policy/openpi-InternData-A1/examples/aloha_sim/requirements.txt
Normal file
@@ -0,0 +1,132 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
absl-py==2.1.0
    # via
    #   dm-control
    #   dm-env
    #   labmaze
    #   mujoco
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
cloudpickle==3.1.0
    # via gymnasium
contourpy==1.3.1
    # via matplotlib
cycler==0.12.1
    # via matplotlib
dm-control==1.0.14
    # via gym-aloha
dm-env==1.6
    # via dm-control
dm-tree==0.1.8
    # via
    #   dm-control
    #   dm-env
docstring-parser==0.16
    # via tyro
farama-notifications==0.0.4
    # via gymnasium
fonttools==4.55.2
    # via matplotlib
glfw==2.8.0
    # via
    #   dm-control
    #   mujoco
gym-aloha==0.1.1
    # via -r examples/aloha_sim/requirements.in
gymnasium==1.0.0
    # via gym-aloha
idna==3.10
    # via requests
imageio==2.36.1
    # via
    #   -r examples/aloha_sim/requirements.in
    #   gym-aloha
imageio-ffmpeg==0.5.1
    # via imageio
kiwisolver==1.4.7
    # via matplotlib
labmaze==1.0.6
    # via dm-control
lxml==5.3.0
    # via dm-control
markdown-it-py==3.0.0
    # via rich
matplotlib==3.9.3
    # via -r examples/aloha_sim/requirements.in
mdurl==0.1.2
    # via markdown-it-py
msgpack==1.1.0
    # via -r examples/aloha_sim/requirements.in
mujoco==2.3.7
    # via
    #   dm-control
    #   gym-aloha
numpy==1.26.4
    # via
    #   -r examples/aloha_sim/requirements.in
    #   contourpy
    #   dm-control
    #   dm-env
    #   gymnasium
    #   imageio
    #   labmaze
    #   matplotlib
    #   mujoco
    #   scipy
packaging==24.2
    # via matplotlib
pillow==11.0.0
    # via
    #   imageio
    #   matplotlib
protobuf==5.29.1
    # via dm-control
psutil==6.1.0
    # via imageio
pygments==2.18.0
    # via rich
pyopengl==3.1.7
    # via
    #   dm-control
    #   mujoco
pyparsing==3.2.0
    # via
    #   dm-control
    #   matplotlib
python-dateutil==2.9.0.post0
    # via matplotlib
requests==2.32.3
    # via dm-control
rich==13.9.4
    # via tyro
scipy==1.14.1
    # via dm-control
setuptools==75.6.0
    # via
    #   dm-control
    #   imageio-ffmpeg
    #   labmaze
shtab==1.7.1
    # via tyro
six==1.17.0
    # via python-dateutil
tqdm==4.67.1
    # via dm-control
typeguard==4.4.1
    # via tyro
typing-extensions==4.12.2
    # via
    #   -r examples/aloha_sim/requirements.in
    #   gymnasium
    #   rich
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/aloha_sim/requirements.in
urllib3==2.2.3
    # via requests
websockets==14.1
    # via -r examples/aloha_sim/requirements.in
40
policy/openpi-InternData-A1/examples/aloha_sim/saver.py
Normal file
@@ -0,0 +1,40 @@
import logging
import pathlib

import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoSaver(_subscriber.Subscriber):
    """Saves episode data."""

    def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
        out_dir.mkdir(parents=True, exist_ok=True)
        self._out_dir = out_dir
        self._images: list[np.ndarray] = []
        self._subsample = subsample

    @override
    def on_episode_start(self) -> None:
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        im = observation["images"]["cam_high"]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]
        self._images.append(im)

    @override
    def on_episode_end(self) -> None:
        existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
        next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
        out_path = self._out_dir / f"out_{next_idx}.mp4"

        logging.info(f"Saving video to {out_path}")
        imageio.mimwrite(
            out_path,
            [np.asarray(x) for x in self._images[:: self._subsample]],
            fps=50 // max(1, self._subsample),
        )
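For reference, the file-numbering rule in `on_episode_end` (next index = highest existing `out_<N>.mp4` plus one, starting at 0) can be exercised on its own. This is a minimal, dependency-free sketch; the `next_output_index` helper is an illustrative stand-in, not part of the commit:

```python
import pathlib
import tempfile

def next_output_index(out_dir: pathlib.Path) -> int:
    """Replicates VideoSaver's naming scheme: out_<N>.mp4 with N = max existing + 1."""
    existing = list(out_dir.glob("out_[0-9]*.mp4"))
    return max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1

with tempfile.TemporaryDirectory() as d:
    out_dir = pathlib.Path(d)
    print(next_output_index(out_dir))  # 0 on an empty directory
    (out_dir / "out_0.mp4").touch()
    (out_dir / "out_7.mp4").touch()
    print(next_output_index(out_dir))  # 8: continues after the highest index
```

Because the index is derived from the directory contents, successive episodes never overwrite earlier videos even across process restarts.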
212
policy/openpi-InternData-A1/examples/arx/action_stats.py
Normal file
@@ -0,0 +1,212 @@
from collections import deque
from typing import List, Dict, Optional, Any, Sequence, Deque, Union
import datasets
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def check_final(
    last_states: Union[Deque[Sequence[float]], Sequence[Sequence[float]], torch.Tensor],
    *,
    # Indices and initial state
    arm_dofs: int = 6,  # number of left-arm joints (6 for this robot)
    gripper_index: int = -1,  # index of the gripper in the state vector (last dim by default)
    mean_initial_arm_state: Optional[Sequence[float]] = (0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108),
    mean_initial_gripper_state: float = 4.8438,  # not used in the decision yet; kept for future extension

    # Decision thresholds (angles are given in degrees for easy tuning and converted to radians internally)
    stability_window: int = 5,  # number of most recent frames used to judge "no large change"
    per_joint_range_deg: float = 2.0,  # max per-joint range (max - min) within the window, in degrees
    mean_speed_deg: float = 0.5,  # threshold on the mean L2 of adjacent-frame joint diffs, in degrees per step
    min_change_from_initial_deg: float = 15.0,  # minimum change of the last frame relative to the initial state (L2, degrees)
    gripper_closed_thresh: float = 0.8,  # gripper-closed threshold (smaller values mean more closed)
) -> bool:
    """
    Returns True when the arm is "in place": (1) the pose barely changes within the recent
    window, (2) the gripper is closed, and (3) the last frame differs sufficiently from the
    initial state. All angle thresholds are given in degrees and converted to radians before
    comparison.
    """
    # --- Arrange the data as an (N, D) tensor ---
    if isinstance(last_states, torch.Tensor):
        states = last_states
    else:
        states = torch.as_tensor(list(last_states), dtype=torch.float32)

    if states.ndim != 2:
        raise ValueError(f"last_states should be 2D, got shape {tuple(states.shape)}")
    N, D = states.shape
    if D < arm_dofs:
        raise ValueError(f"Expected at least {arm_dofs} dims for arm + gripper, got {D}")
    if N < 2:
        return False  # too few samples to judge stability

    # Take the most recent window
    w = min(N, stability_window)
    window = states[-w:]  # (w, D)
    arm = window[:, :arm_dofs]  # (w, 6)
    last_arm = arm[-1]  # (6,)
    last_gripper = float(window[-1, gripper_index])

    # --- 1) "No large change" over the last w frames ---
    # Two criteria: each joint's range (max - min) must be small, and the mean
    # frame-to-frame "speed" must be small.
    deg2rad = torch.pi / 180.0
    range_tol = per_joint_range_deg * deg2rad
    speed_tol = mean_speed_deg * deg2rad

    ranges = arm.max(dim=0).values - arm.min(dim=0).values  # (6,)
    max_range = float(ranges.abs().max())  # scalar
    diffs = arm[1:] - arm[:-1]  # (w-1, 6)
    mean_speed = float(torch.linalg.norm(diffs, dim=1).mean())  # mean L2 per step

    stable = (max_range <= range_tol) and (mean_speed <= speed_tol)

    # --- 2) Gripper closed ---
    gripper_closed = (last_gripper < gripper_closed_thresh)

    # --- 3) The last frame must differ sufficiently from the "initial" state ---
    init = torch.as_tensor(mean_initial_arm_state, dtype=last_arm.dtype, device=last_arm.device)
    if init.numel() != arm_dofs:
        raise ValueError(f"mean_initial_arm_state length {init.numel()} != arm_dofs {arm_dofs}")
    dist_from_init = float(torch.linalg.norm(last_arm - init))
    far_from_init = (dist_from_init >= (min_change_from_initial_deg * deg2rad))

    # Combined decision
    return bool(stable and gripper_closed and far_from_init)
    # return bool(gripper_closed and far_from_init)


def get_last_frames(ds: LeRobotDataset, include_images: bool = False, keys=None):
    """
    Quickly fetch the last frame of each episode in a LeRobotDataset.
    - include_images=False: Return only scalar/vector fields from parquet (faster, no video decoding).
    - include_images=True : Additionally decode the corresponding image/video frame for the last frame.
    - keys: Limit the set of columns to retrieve (default: all non-image/video fields + timestamp, etc.).
    Returns: list[dict], where each element contains the last frame info of one episode.
    """
    # 1) Compute the global index of the last row for each episode.
    #    ds.episode_data_index['to'] is the exclusive end index, so last frame = to - 1.
    end_idxs = (ds.episode_data_index["to"] - 1).tolist()

    # 2) Determine which columns to load.
    #    By default, exclude video/image columns to avoid triggering slow video decoding.
    if keys is None:
        non_media_keys = [k for k, ft in ds.features.items() if ft["dtype"] not in ("image", "video")]
        keys = list(set(non_media_keys + ["timestamp", "episode_index", "task_index"]))

    # 3) Select all last-frame rows at once (does not call __getitem__, so no video decoding is triggered).
    last_rows = ds.hf_dataset.select(end_idxs)

    # 4) Build a dictionary of tensors for each requested key.
    out = []
    col = {k: last_rows[k] for k in keys}

    # Convert lists of tensors into stacked tensors for easier indexing.
    for k, v in col.items():
        # datasets.arrow_dataset.Column is the HuggingFace internal type for columns.
        if isinstance(v, datasets.arrow_dataset.Column) and len(v) > 0 and hasattr(v[0], "shape"):
            col[k] = torch.stack(v[:])

    # Iterate through each episode's last frame and build a dict with its values.
    for i, ep_end in enumerate(end_idxs):
        item = {}
        for k in keys:
            val = col[k][i]
            # Unpack 0-dimensional tensors into Python scalars.
            if torch.is_tensor(val) and val.ndim == 0:
                val = val.item()
            item[k] = val

        # Map task_index back to the human-readable task string.
        if "task_index" in item:
            item["task"] = ds.meta.tasks[int(item["task_index"])]
        out.append(item)

    # 5) Optionally decode the actual image/video frame for each last timestamp.
    if include_images and len(ds.meta.video_keys) > 0:
        for i, ep_end in enumerate(end_idxs):
            ep_idx = int(out[i]["episode_index"])
            ts = float(out[i]["timestamp"])
            # Prepare a query dictionary: one timestamp per camera key.
            query_ts = {k: [ts] for k in ds.meta.video_keys}
            # Decode video frames at the specified timestamps for this episode.
            frames = ds._query_videos(query_ts, ep_idx)
            # Attach the decoded frame tensors to the output dictionary.
            for k, v in frames.items():
                out[i][k] = v

    return out


if __name__ == "__main__":
    # Initialize your dataset (replace with your repo ID or local path).
    ds = LeRobotDataset(repo_id="arx_lift2/pick_parcel_20250915")

    # Retrieve metadata only (timestamps, states, actions, tasks) without decoding video.
    last_infos = get_last_frames(ds, include_images=False)

    # Stack all 'observation.state' vectors into a single tensor for further processing.
    states = torch.stack([info['observation.state'] for info in last_infos])
    # Extract the left-arm joint states (first 7 values of each state vector).
    left_arm_states = states[:, 0:7]
    mean_state = torch.mean(left_arm_states, dim=0)
    std_state = torch.std(left_arm_states, dim=0)

    # Print the collected metadata for verification.
    print(last_infos)

    # --- Run check_final per episode using the last <= STABILITY_WINDOW states ---

    EP_ARM_DOFS = 6  # number of left-arm joints we use in check_final
    GRIPPER_COL_FULL = -1  # gripper is the last element in the full state vector
    STABILITY_WINDOW = 120  # must be consistent with the window passed to check_final

    # Determine which episodes to iterate
    episode_indices = ds.episodes if ds.episodes is not None else sorted(ds.meta.episodes.keys())

    episode_flags = {}
    num_true, num_false = 0, 0

    for ep_idx in episode_indices:
        # Global index range [from_idx, to_idx) for this episode
        from_idx = int(ds.episode_data_index["from"][ep_idx])
        to_idx = int(ds.episode_data_index["to"][ep_idx])

        if to_idx - from_idx <= 0:
            episode_flags[ep_idx] = False
            num_false += 1
            continue

        # Take the last <= STABILITY_WINDOW frames from this episode
        idxs = list(range(max(from_idx, to_idx - STABILITY_WINDOW), to_idx))
        rows = ds.hf_dataset.select(idxs)

        # Collect full "observation.state" (shape ~ [W, S])
        s_col = rows["observation.state"]
        if isinstance(s_col, datasets.arrow_dataset.Column):
            S = torch.stack(s_col[:])  # Column -> list[tensor] -> stack
        else:
            S = torch.stack(s_col)  # already a list[tensor]

        # Build the 7D small state per frame: first 6 joints + gripper
        # (Assumes the gripper signal is at the last position of the full state vector)
        small_states = torch.cat([S[:, :EP_ARM_DOFS], S[:, EP_ARM_DOFS:EP_ARM_DOFS+1]], dim=1)

        # Run the stopping logic
        ok = check_final(
            small_states,
            arm_dofs=EP_ARM_DOFS,
            gripper_index=-1,
            stability_window=STABILITY_WINDOW,
        )
        episode_flags[ep_idx] = bool(ok)
        num_true += int(ok)
        num_false += int(not ok)

    # Summary
    total_eps = len(episode_indices)
    print(f"[check_final] passed: {num_true} / {total_eps} ({(num_true/max(total_eps,1)):.1%})")

    # List some failed episodes for quick inspection
    failed_eps = [e for e, passed in episode_flags.items() if not passed]
    print("Failed episode indices (first 20):", failed_eps[:20])
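The three-way decision in `check_final` can be restated without torch to make the criteria concrete. This is a dependency-free sketch: `check_final_sketch`, the zero initial pose, and the `held`/`open_grip` test arrays are illustrative stand-ins, and the mean-speed criterion is omitted for brevity:

```python
import math

def check_final_sketch(states, arm_dofs=6, per_joint_range_deg=2.0,
                       min_change_from_initial_deg=15.0, gripper_closed_thresh=0.8,
                       initial=(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)):
    """Plain-Python restatement of the three criteria: window stability,
    gripper closed, and sufficient distance from the initial pose."""
    deg2rad = math.pi / 180.0
    arm = [row[:arm_dofs] for row in states]
    # 1) per-joint range within the window must stay under the tolerance
    ranges = [max(col) - min(col) for col in zip(*arm)]
    stable = max(ranges) <= per_joint_range_deg * deg2rad
    # 2) the gripper (last element) must be below the "closed" threshold
    gripper_closed = states[-1][-1] < gripper_closed_thresh
    # 3) the last pose must be far enough (L2) from the initial pose
    dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(arm[-1], initial)))
    far = dist >= min_change_from_initial_deg * deg2rad
    return stable and gripper_closed and far

# A held pose far from the initial pose with a closed gripper passes:
held = [[0.5, 0.4, 0.3, -0.2, 0.1, 0.0, 0.2]] * 5
print(check_final_sketch(held))  # True
# The same pose with an open gripper fails criterion 2:
open_grip = [[0.5, 0.4, 0.3, -0.2, 0.1, 0.0, 3.0]] * 5
print(check_final_sketch(open_grip))  # False
```

Expressing the thresholds in degrees and converting once keeps the tuning knobs human-readable while the comparison stays in radians, matching the real function.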
88
policy/openpi-InternData-A1/examples/arx/extract_frame.py
Normal file
@@ -0,0 +1,88 @@
import os
import cv2
from pathlib import Path
from tqdm import tqdm

def extract_last_frame_from_videos(root_dir, output_dir, xx_last_frame=1):
    """
    Walk the directory tree, find all head-camera mp4 files (paths containing
    'observation/head'), extract the last frame of each, and save it.
    """
    # Find all mp4 video files
    video_files = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            if file.endswith('.mp4') and 'observation/head' in root:
                video_files.append(os.path.join(root, file))

    print(f"Found {len(video_files)} video files")

    # Process each video file
    for video_path in tqdm(video_files):
        try:
            # Extract the set name and episode name
            path_parts = Path(video_path).parts
            set_name = None
            episode_name = None
            for part in path_parts:
                if part.startswith('set'):
                    set_name = part
                if part.startswith('000'):
                    episode_name = part.replace('.mp4', '')

            if not set_name or not episode_name:
                print(f"Could not extract set and episode info from path: {video_path}")
                continue

            # Build the output file name
            output_filename = f"{set_name}_{episode_name}.jpg"
            output_path = os.path.join(output_dir, output_filename)

            # Open the video file
            cap = cv2.VideoCapture(video_path)

            if not cap.isOpened():
                print(f"Could not open video: {video_path}")
                continue

            # Get the total frame count
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if total_frames == 0:
                print(f"Video has no frames: {video_path}")
                cap.release()
                continue

            # Seek to the last frame
            cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - xx_last_frame)
            ret, frame = cap.read()

            if ret:
                # Save the last frame
                cv2.imwrite(output_path, frame)
                print(f"Saved:\n  {output_path}")
            else:
                print(f"Could not read the last frame: {video_path}")

            # Release resources
            cap.release()

        except Exception as e:
            print(f"Error while processing video {video_path}: {str(e)}")

if __name__ == "__main__":
    # Root directory to walk
    root_directory = "/home/caijunhao/h-ceph/InternData-A1-raw/arx_lift2/Pick_the_industrial_components_from_the_conveyor"  # change this to your own directory path
    output_path = 'data/Pick_the_industrial_components_from_the_conveyor/'
    os.makedirs(output_path, exist_ok=True)
    sub_list = os.listdir(root_directory)
    exclude_list = []
    # exclude_list = [f"{i}" for i in range(16)] + [f"{i}" for i in range(26, 29)]
    xx_last_frame = 1
    # import pdb
    # pdb.set_trace()
    for sub in tqdm(sub_list):
        if sub.split('-')[1].split('_')[0] in exclude_list:
            continue
        # print("os.path.join([root_directory, sub])\n", os.path.join(root_directory, sub))
        extract_last_frame_from_videos(os.path.join(root_directory, sub), output_path, xx_last_frame=xx_last_frame)
    print("Done!")
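The path-parsing rule above (a `set*` component names the set, a component starting with `000` names the episode) can be isolated for testing. A small sketch, with a hypothetical `parse_set_and_episode` helper and an invented example path:

```python
from pathlib import Path

def parse_set_and_episode(video_path: str):
    """Mirrors the parsing rule in extract_last_frame_from_videos: the 'set*'
    path component names the set, and a component starting with '000' names
    the episode (with any '.mp4' suffix stripped)."""
    set_name = episode_name = None
    for part in Path(video_path).parts:
        if part.startswith('set'):
            set_name = part
        if part.startswith('000'):
            episode_name = part.replace('.mp4', '')
    return set_name, episode_name

print(parse_set_and_episode("raw/set03-task_a/observation/head/0007.mp4"))
# ('set03-task_a', '0007')
```

Note that the whole `set*` directory name is kept, so output files like `set03-task_a_0007.jpg` stay unique across collectors.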
670
policy/openpi-InternData-A1/examples/arx/lmdb2lerobot_arx.py
Normal file
@@ -0,0 +1,670 @@
# source /fs-computility/efm/liyang/miniconda3/etc/profile.d/conda.sh
# conda activate act

import argparse
import json
import logging
import os
import gc
import shutil
from concurrent.futures import ALL_COMPLETED, ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
import torchvision
import cv2
import h5py
import lmdb
import numpy as np
import pickle
import torch
from PIL import Image
from scipy.spatial.transform import Rotation
from tqdm import tqdm
import pdb
import imageio  # imageio-ffmpeg
from lerobot.common.datasets.compute_stats import auto_downsample_height_width, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import check_timestamps_sync, get_episode_data_index, validate_episode_buffer
import time
# import ray
# from ray.runtime_env import RuntimeEnv

"""
Store both camera image and robot state as a combined observation.
Args:
    observation: images (camera), states (robot state)
    actions: joint, gripper, ee_pose
"""
FEATURES = {
    "images.rgb.head": {
        "dtype": "video",
        "shape": (368, 640, 3),
        "names": ["height", "width", "channel"],
    },
    "images.rgb.hand_left": {
        "dtype": "video",
        "shape": (480, 640, 3),
        "names": ["height", "width", "channel"],
    },
    "images.rgb.hand_right": {
        "dtype": "video",
        "shape": (480, 640, 3),
        "names": ["height", "width", "channel"],
    },
    "observation.state": {
        "dtype": "float32",
        "shape": (14,),
        "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
                  "right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5", "right_gripper_0"],
    },
    "action": {
        "dtype": "float32",
        "shape": (14,),
        "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
                  "right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5", "right_gripper_0"],
    },
}
def filter_forbidden_frames(state_dict, position_threshold=0.001, velocity_threshold=0.005):
    """
    Filter out forbidden frames based on position and velocity thresholds.

    Args:
        state_dict: state array of shape (n, 14)
        position_threshold: threshold on position change
        velocity_threshold: threshold on velocity change

    Returns:
        valid_mask: boolean array where True marks a valid frame
    """
    # Use all 14 columns (the gripper columns are included here)
    qpos_columns = [i for i in range(14)]
    qpos_data = state_dict[:, qpos_columns]

    n_frames = len(state_dict)
    valid_mask = np.ones(n_frames, dtype=bool)
    # Compute frame-to-frame differences (velocities)
    if n_frames > 1:
        diff_sum = np.sum(np.abs(np.diff(qpos_data, axis=0)), axis=1)
        # Check whether the motion exceeds the threshold
        for i in range(n_frames - 1):
            if np.any(np.abs(diff_sum[i]) > position_threshold):
                valid_mask[i] = True  # motion present: valid frame
            else:
                valid_mask[i] = False  # static: possibly a forbidden frame
            valid_mask[i] = True  # NOTE: this override keeps every frame, disabling the filter
    return valid_mask

def statistical_filter(state_dict, std_multiplier=2.0):
    """
    Detect anomalous (forbidden) frames with simple statistics.
    """
    # Exclude the gripper columns (columns 6 and 13, 0-indexed)
    qpos_columns = [i for i in range(14) if i not in [6, 13]]
    qpos_data = state_dict[:, qpos_columns]

    # Per-column mean and standard deviation
    means = np.mean(qpos_data, axis=0)
    stds = np.std(qpos_data, axis=0)

    # Build the validity mask
    valid_mask = np.ones(len(state_dict), dtype=bool)

    for i in range(len(state_dict)):
        # Check whether each joint position stays within a reasonable range
        deviations = np.abs(qpos_data[i] - means)
        if np.any(deviations > std_multiplier * stds):
            valid_mask[i] = False  # anomalous frame

    return valid_mask
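The diff-sum test in `filter_forbidden_frames` (a frame counts as "moving" when the summed absolute change to the next frame exceeds the threshold) can be shown on plain lists. This is an illustrative pure-Python sketch; the shipped function operates on `(n, 14)` NumPy arrays and, as written, ultimately overrides every entry to `True`, so the sketch shows the intended test without that override:

```python
def motion_mask(rows, position_threshold=0.001):
    """Sketch of the diff-sum test: frame i is 'moving' when the summed absolute
    change from frame i to frame i+1 exceeds position_threshold."""
    mask = [True] * len(rows)  # the last frame has no successor and stays valid
    for i in range(len(rows) - 1):
        diff_sum = sum(abs(a - b) for a, b in zip(rows[i + 1], rows[i]))
        mask[i] = diff_sum > position_threshold
    return mask

rows = [[0.0, 0.0], [0.0, 0.0], [0.1, 0.0], [0.1, 0.0]]
print(motion_mask(rows))  # [False, True, False, True]
```

Static stretches (identical consecutive rows) are flagged `False`, which is what makes the filter useful for dropping idle frames at the start and end of a teleoperated episode.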
class ARXDataset(LeRobotDataset):
    def __init__(
        self,
        repo_id: str,
        root: str | Path | None = None,
        episodes: list[int] | None = None,
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
        tolerance_s: float = 1e-4,
        download_videos: bool = True,
        local_files_only: bool = False,
        video_backend: str | None = None,
    ):
        super().__init__(
            repo_id=repo_id,
            root=root,
            episodes=episodes,
            image_transforms=image_transforms,
            delta_timestamps=delta_timestamps,
            tolerance_s=tolerance_s,
            download_videos=download_videos,
            local_files_only=local_files_only,
            video_backend=video_backend,
        )

    def save_episode(self, episode_data: dict | None = None, videos: dict | None = None) -> None:
        if not episode_data:
            episode_buffer = self.episode_buffer

        validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
        episode_length = episode_buffer.pop("size")
        tasks = episode_buffer.pop("task")
        episode_tasks = list(set(tasks))
        episode_index = episode_buffer["episode_index"]

        episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

        for task in episode_tasks:
            task_index = self.meta.get_task_index(task)
            if task_index is None:
                self.meta.add_task(task)

        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
        for key, ft in self.features.items():
            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
                continue
            episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
        for key in self.meta.video_keys:
            video_path = self.root / self.meta.get_video_file_path(episode_index, key)
            episode_buffer[key] = str(video_path)  # PosixPath -> str
            video_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(videos[key], video_path)
        ep_stats = compute_episode_stats(episode_buffer, self.features)
        self._save_episode_table(episode_buffer, episode_index)
        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
        ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
        ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
        check_timestamps_sync(
            episode_buffer["timestamp"],
            episode_buffer["episode_index"],
            ep_data_index_np,
            self.fps,
            self.tolerance_s,
        )
        if not episode_data:
            self.episode_buffer = self.create_episode_buffer()

    def add_frame(self, frame: dict) -> None:
        for name in frame:
            if isinstance(frame[name], torch.Tensor):
                frame[name] = frame[name].numpy()
        features = {key: value for key, value in self.features.items() if key in self.hf_features}
        if self.episode_buffer is None:
            self.episode_buffer = self.create_episode_buffer()
        frame_index = self.episode_buffer["size"]
        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
        self.episode_buffer["frame_index"].append(frame_index)
        self.episode_buffer["timestamp"].append(timestamp)

        for key in frame:
            if key == "task":
                self.episode_buffer["task"].append(frame["task"])
                continue
            if key not in self.features:
                print("key ", key)
                raise ValueError(f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'.")
            self.episode_buffer[key].append(frame[key])
        self.episode_buffer["size"] += 1
# def crop_resize_no_padding(image, target_size=(480, 640)):
#     """
#     Crop and scale to the target size (no padding)
#     :param image: input image (NumPy array)
#     :param target_size: target size (height, width)
#     :return: processed image
#     """
#     h, w = image.shape[:2]
#     target_h, target_w = target_size
#     target_ratio = target_w / target_h  # target aspect ratio (e.g. 640/480 = 1.333)
#
#     # Compare the original aspect ratio to choose the cropping direction
#     if w / h > target_ratio:  # original image is wider -> crop width
#         crop_w = int(h * target_ratio)  # crop width derived from the target aspect ratio
#         crop_h = h
#         start_x = (w - crop_w) // 2  # horizontal center starting point
#         start_y = 0
#     else:  # original image is taller -> crop height
#         crop_h = int(w / target_ratio)  # crop height derived from the target aspect ratio
#         crop_w = w
#         start_x = 0
#         start_y = (h - crop_h) // 2  # vertical center starting point
#
#     # Perform centered cropping (clamp to prevent out-of-bounds)
#     start_x, start_y = max(0, start_x), max(0, start_y)
#     end_x, end_y = min(w, start_x + crop_w), min(h, start_y + crop_h)
#     cropped = image[start_y:end_y, start_x:end_x]
#
#     # Resize to the target size (bilinear interpolation)
#     resized = cv2.resize(cropped, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
#     return resized


def load_lmdb_data(episode_path: Path, sava_path: Path, fps_factor: int, target_fps: int) -> Optional[Dict]:
    def load_image(txn, key):
        raw = txn.get(key)
        data = pickle.loads(raw)
        image = cv2.imdecode(data, cv2.IMREAD_COLOR)
        # Convert to RGB if necessary
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # image = crop_resize_no_padding(image, target_size=(480, 640))
        return image

    try:
        env = lmdb.open(
            str(episode_path / "lmdb"),
            readonly=True,
            lock=False,
            max_readers=128,
            readahead=False
        )
        with env.begin(write=False) as txn:
            keys = [k for k, _ in txn.cursor()]

            image_keys = sorted([k for k in keys if b'head' in k])
            if not image_keys:
                return None

            all_qpos = pickle.loads(txn.get(b'/observations/qpos'))

            if np.isscalar(all_qpos):
                total_steps = len(image_keys)
                all_qpos = [all_qpos] * total_steps
            else:
                total_steps = len(all_qpos)
            all_qpos = np.stack(all_qpos)
            state_action_dict = {}
            state_action_dict["states.left_joint.position"] = all_qpos[:, :6]
            state_action_dict["states.left_gripper.position"] = all_qpos[:, 6][:, None]
            state_action_dict["states.right_joint.position"] = all_qpos[:, 7:13]
            state_action_dict["states.right_gripper.position"] = all_qpos[:, 13][:, None]

            action_dict = {}
            action_dict["action"] = np.concatenate([all_qpos[1:, ], all_qpos[-1, ].reshape(-1, 14)], axis=0)
            state_dict = {}
            state_dict["observation.state"] = all_qpos
            mask1 = filter_forbidden_frames(state_dict["observation.state"])

            assert total_steps == len(image_keys), "qpos length mismatch"
            selected_steps = [step for step in range(total_steps) if step % fps_factor == 0 and mask1[step]]
            frames = []
            image_observations = {
                "images.rgb.head": [],
                "images.rgb.hand_left": [],
                "images.rgb.hand_right": []
            }

            start_time = time.time()

            for step_index, step in enumerate(selected_steps):
                step_str = f"{step:04d}"
                head_key = f"observation/head/color_image/{step_str}".encode()
                left_key = f"observation/left_wrist/color_image/{step_str}".encode()
                right_key = f"observation/right_wrist/color_image/{step_str}".encode()
                if not (head_key in keys and left_key in keys and right_key in keys):
                    continue
                data_dict = {}
                data_dict['action'] = action_dict["action"][step]
                data_dict["task"] = " ".join(episode_path.parent.parent.name.split("_"))
                data_dict['observation.state'] = state_dict["observation.state"][step]
                frames.append(data_dict)
                image_observations["images.rgb.head"].append(load_image(txn, head_key))
                image_observations["images.rgb.hand_left"].append(load_image(txn, left_key))
                image_observations["images.rgb.hand_right"].append(load_image(txn, right_key))
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"load image_observations of {episode_path}")
        env.close()
        if not frames:
            return None
        os.makedirs(sava_path, exist_ok=True)
        os.makedirs(sava_path / episode_path.name, exist_ok=True)
        imageio.mimsave(sava_path / episode_path.name / 'head.mp4', image_observations["images.rgb.head"], fps=target_fps)
        imageio.mimsave(sava_path / episode_path.name / 'hand_left.mp4', image_observations["images.rgb.hand_left"], fps=target_fps)
        imageio.mimsave(sava_path / episode_path.name / 'hand_right.mp4', image_observations["images.rgb.hand_right"], fps=target_fps)
        print(f"imageio.mimsave done for {episode_path}")

        return {
            "frames": frames,
            "videos": {
                "images.rgb.head": sava_path / episode_path.name / "head.mp4",
                "images.rgb.hand_left": sava_path / episode_path.name / "hand_left.mp4",
                "images.rgb.hand_right": sava_path / episode_path.name / "hand_right.mp4",
            },
        }

    except Exception as e:
        logging.error(f"Failed to load LMDB data: {e}")
        return None


def get_all_tasks(src_path: Path, output_path: Path) -> Tuple[Path, Path]:
    src_dirs = sorted(list(src_path.glob("*")))  # "set*-*_collector*_datatime" is the conversion unit

    save_dirs = [output_path / _dir.parent.name / _dir.name for _dir in src_dirs]
    tasks_tuples = zip(src_dirs, save_dirs)
    for task in tasks_tuples:
        yield task
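The source-to-destination pairing that `get_all_tasks` yields is easy to verify in isolation. A small stand-alone sketch with a hypothetical `pair_src_and_save` helper and invented paths:

```python
from pathlib import Path

def pair_src_and_save(src_dirs, output_path: Path):
    """Mirrors get_all_tasks' pairing: each source dir maps to
    output_path / <parent dir name> / <dir name>."""
    for d in src_dirs:
        yield d, output_path / d.parent.name / d.name

pairs = list(pair_src_and_save([Path("raw/taskA/set01"), Path("raw/taskA/set02")], Path("out")))
print(pairs[0][1].as_posix())  # out/taskA/set01
```

Keeping the task directory name in the destination path means multiple tasks converted into the same output root cannot collide.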
def compute_episode_stats(episode_data: Dict[str, List[str] | np.ndarray], features: Dict) -> Dict:
    ep_stats = {}
    for key, data in episode_data.items():
        if features[key]["dtype"] == "string":
            continue
        elif features[key]["dtype"] in ["image", "video"]:
            ep_ft_array = sample_images(data)
            axes_to_reduce = (0, 2, 3)  # keep channel dim
            keepdims = True
        else:
            ep_ft_array = data  # data is already a np.ndarray
            axes_to_reduce = 0  # compute stats over the first axis
            keepdims = data.ndim == 1  # keep as np.array

        ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
        if features[key]["dtype"] in ["image", "video"]:
            ep_stats[key] = {
                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
            }
    return ep_stats
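The per-feature reduction above keeps only the channel dimension for image data. A minimal sketch of the idea, assuming `get_feature_stats` computes min/max/mean/std plus a sample count (its definition is not shown here):

```python
import numpy as np

# Hypothetical stand-in for get_feature_stats: min/max/mean/std + count.
def feature_stats(arr: np.ndarray, axis, keepdims: bool) -> dict:
    return {
        "min": arr.min(axis=axis, keepdims=keepdims),
        "max": arr.max(axis=axis, keepdims=keepdims),
        "mean": arr.mean(axis=axis, keepdims=keepdims),
        "std": arr.std(axis=axis, keepdims=keepdims),
        "count": np.array([arr.shape[0]]),
    }

# Images shaped [T, C, H, W]: reducing over (0, 2, 3) keeps only the channel dim.
frames = np.random.randint(0, 256, size=(8, 3, 4, 4)).astype(np.float64)
stats = feature_stats(frames, axis=(0, 2, 3), keepdims=True)
print(stats["mean"].shape)  # (1, 3, 1, 1)
```

The subsequent `np.squeeze(v / 255.0, axis=0)` then drops the leading time axis and rescales pixel stats to [0, 1].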


def sample_images(input):
    # Both input types produce a frames array shaped [T, C, H, W]; the
    # sampling/downsampling loop is identical, so it is shared below.
    if isinstance(input, str):
        video_path = input
        reader = torchvision.io.VideoReader(video_path, stream="video")
        frames = [frame["data"] for frame in reader]
        frames_array = torch.stack(frames).numpy()  # Shape: [T, C, H, W]
    elif isinstance(input, np.ndarray):
        frames_array = input[:, None, :, :]  # Shape: [T, C, H, W]
    sampled_indices = sample_indices(len(frames_array))
    images = None
    for i, idx in enumerate(sampled_indices):
        img = frames_array[idx]
        img = auto_downsample_height_width(img)
        if images is None:
            images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
        images[i] = img
    return images
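`sample_indices` is not shown in this chunk; a common choice (an assumption here) is to take a bounded number of evenly spaced frame indices:

```python
import numpy as np

# Hypothetical sketch of sample_indices: evenly spaced indices, capped at
# max_samples, so stats are computed from a representative subset of frames.
def sample_indices(num_frames: int, max_samples: int = 100) -> np.ndarray:
    n = min(num_frames, max_samples)
    return np.round(np.linspace(0, num_frames - 1, n)).astype(int)

idx = sample_indices(300, max_samples=5)
print(idx.tolist())  # [0, 75, 150, 224, 299]
```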


def load_local_dataset(episode_path: str, save_path: str, origin_fps=30, target_fps=30):
    fps_factor = origin_fps // target_fps
    episode_path = Path(episode_path)
    if not episode_path.exists():
        logging.warning(f"{episode_path} does not exist")
        return None, None

    if not (episode_path / "lmdb/data.mdb").exists():
        logging.warning(f"LMDB data not found for episode {episode_path}")
        return None, None

    raw_dataset = load_lmdb_data(episode_path, save_path, fps_factor, target_fps)
    if raw_dataset is None:
        return None, None
    frames = raw_dataset["frames"]  # states, actions, task
    videos = raw_dataset["videos"]  # image paths

    # check that every referenced video file exists
    for camera_name, video_path in videos.items():
        if not os.path.exists(video_path):
            logging.error(f"Camera {camera_name}: video file {video_path} does not exist.")
            return None, None
    return frames, videos
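The `fps_factor` computed above implies simple frame dropping: keeping every `fps_factor`-th frame converts `origin_fps` footage to `target_fps`. A minimal sketch of that relationship (the actual downsampling happens inside `load_lmdb_data`, which is not fully shown here):

```python
# Sketch only: keep every fps_factor-th frame to reduce the frame rate.
def downsample_frames(frames, origin_fps: int = 30, target_fps: int = 10):
    assert origin_fps % target_fps == 0, "origin_fps must be a multiple of target_fps"
    fps_factor = origin_fps // target_fps
    return frames[::fps_factor]

print(downsample_frames(list(range(12)), 30, 10))  # [0, 3, 6, 9]
```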


def save_as_lerobot_dataset(task: tuple[Path, Path], repo_id, num_threads, debug, origin_fps=30, target_fps=30, robot_type="piper", delete_downsampled_videos=True):
    src_path, save_path = task
    print(f"**Processing collected** {src_path}")
    print(f"**saving to** {save_path}")
    if save_path.exists():
        print(f"Output directory {save_path} already exists, skipping.")
        return

    dataset = ARXDataset.create(
        repo_id=f"{repo_id}",
        root=save_path,
        fps=target_fps,
        robot_type=robot_type,
        features=FEATURES,
    )
    all_episode_paths = sorted(f.as_posix() for f in src_path.glob("*") if f.is_dir())
    if debug:
        # Process a single episode serially for easier debugging.
        frames, videos = load_local_dataset(episode_path=all_episode_paths[0], save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
        for frame_data in frames:
            dataset.add_frame(frame_data)
        dataset.save_episode(videos=videos)
        if delete_downsampled_videos:
            for _, video_path in videos.items():
                parent_dir = os.path.dirname(video_path)
                try:
                    shutil.rmtree(parent_dir)
                    print(f"Successfully deleted: {parent_dir}")
                except OSError:
                    pass  # directory may not exist or was already deleted
    else:
        for batch_index in range(len(all_episode_paths) // num_threads + 1):
            batch_episode_paths = all_episode_paths[batch_index * num_threads:(batch_index + 1) * num_threads]
            if len(batch_episode_paths) == 0:
                continue
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                # Map each future back to its episode path so error messages
                # report the correct episode (the bare loop variable would be stale).
                future_to_path = {}
                for episode_path in batch_episode_paths:
                    print("starting to process episode:", episode_path)
                    future = executor.submit(load_local_dataset, episode_path=episode_path, save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
                    future_to_path[future] = episode_path
                for future in as_completed(future_to_path):
                    frames, videos = future.result()
                    if frames is None or videos is None:
                        print(f"Skipping episode {future_to_path[future]} due to missing data.")
                        continue
                    for frame_data in frames:
                        dataset.add_frame(frame_data)
                    dataset.save_episode(videos=videos)
                    gc.collect()
                    print(f"finished processing {videos}")
                    if delete_downsampled_videos:
                        for _, video_path in videos.items():
                            # Remove the per-episode directory holding the downsampled videos.
                            parent_dir = os.path.dirname(video_path)
                            try:
                                shutil.rmtree(parent_dir)
                                print(f"Successfully deleted: {parent_dir}")
                            except OSError:
                                pass  # directory may not exist or was already deleted


def main(src_path, save_path, repo_id, num_threads=60, debug=False, origin_fps=30, target_fps=30):
    logging.info("Scanning for episodes...")
    tasks = get_all_tasks(src_path, save_path)
    if debug:
        task = next(tasks)
        save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
    else:
        for task in tasks:
            save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert collected data from Piper to Lerobot format.")
    parser.add_argument(
        "--src_path",
        type=str,
        default="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/",
        help="Path to the input directory containing collected data in Piper format.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="/fs-computility/efm/shared/datasets/myData-A1/real/lerobot_v2_1/agilex_split_aloha/",
        help="Path to the output directory where the converted Lerobot dataset will be saved.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Run in debug mode with limited episodes",
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=50,
        help="Number of threads per process",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="Identifier for the dataset repository.",
    )
    parser.add_argument(
        "--origin_fps",
        type=int,
        default=30,
        help="Frames per second of the observation video. Default is 30.",
    )
    parser.add_argument(
        "--target_fps",
        type=int,
        default=30,
        help="Frames per second of the downsampled video. Default is 30.",
    )
    args = parser.parse_args()
    assert int(args.origin_fps) % int(args.target_fps) == 0, "origin_fps must be an integer multiple of target_fps"
    start_time = time.time()
    main(
        src_path=Path(args.src_path),
        save_path=Path(args.save_path),
        repo_id=args.repo_id,
        num_threads=args.num_threads,
        debug=args.debug,
        origin_fps=args.origin_fps,
        target_fps=args.target_fps,
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Total time taken: {elapsed_time:.2f} seconds")

# Example:
# --target_fps 10
# --src_path /fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Put_the_bananas_in_the_basket
# --save_path /mnt/shared-storage-user/internvla/Users/liyang/data/processed_data/arx_lift2
1693 policy/openpi-InternData-A1/examples/arx/merge_lerobot_data.py Normal file
File diff suppressed because it is too large
1509 policy/openpi-InternData-A1/examples/arx/merge_lerobot_data_v2.py Normal file
File diff suppressed because it is too large
@@ -0,0 +1,587 @@
#!/usr/bin/env python3
"""
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.

This script loads a JAX model checkpoint using orbax and can either:
1. Print out all the parameter keys in a hierarchical structure for inspection
2. Convert the JAX model to PyTorch format using our PI0Pytorch model

Usage:
    # Just inspect keys:
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only

    # Convert to PyTorch:
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output

Example:
    # pi0_droid
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch

    # pi0_aloha_sim
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch

    # pi05_droid
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
"""

import json
import os
import pathlib
import shutil
from typing import Literal

from flax.nnx import traversals
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import torch
import tyro

import openpi.models.gemma
import openpi.models.model
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config


def slice_paligemma_state_dict(state_dict, config):
    """Convert PaliGemma JAX parameters to PyTorch format."""
    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""

    # patch embeddings
    jax_key = f"img/embedding/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)

    jax_key = f"img/embedding/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # positional embeddings
    jax_key = f"img/pos_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)

    # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")

    encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
    encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
    encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
    encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")

    encoderblock_attention_0_key_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
    )
    encoderblock_attention_0_key_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
    )
    encoderblock_attention_0_value_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
    )
    encoderblock_attention_0_value_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
    )
    encoderblock_attention_0_query_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
    )
    encoderblock_attention_0_query_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
    )
    encoderblock_attention_0_out_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
    )
    encoderblock_attention_0_out_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
    )

    for i in range(config.vision_config.num_hidden_layers):
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
        ] = encoderblock_layernorm0_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
        ] = encoderblock_layernorm0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
        ] = encoderblock_layernorm1_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
        ] = encoderblock_layernorm1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
        ] = encoderblock_mlp_dense0_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
        ] = encoderblock_mlp_dense0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
        ] = encoderblock_mlp_dense1_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
        ] = encoderblock_mlp_dense1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
        ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
        ] = encoderblock_attention_0_key_bias[i].reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
        ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
        ] = encoderblock_attention_0_value_bias[i].reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
        ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
        ] = encoderblock_attention_0_query_bias[i].reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
        ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
        ] = encoderblock_attention_0_out_bias[i].reshape(-1)

    jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # multimodal projector
    jax_key = f"img/head/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/head/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # text decoder (gemma)
    jax_key = f"llm/embedder/input_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # pop the einsum attention + mlp representations
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")

    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")

    for i in range(config.text_config.num_hidden_layers):
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .transpose(2, 0, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
            llm_input_layernorm[i]
        )
        state_dict[
            f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
        ] = llm_post_attention_layernorm[i]

    jax_key = f"llm/final_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    expert_dict = {}
    final_state_dict = {}

    # Expert-related keys to extract (including pi05 Dense layer parameters)
    expert_keys = [
        f"llm/final_norm_1/scale{suffix}",
        f"llm/final_norm_1/Dense_0/bias{suffix}",
        f"llm/final_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
        f"llm/layers/attn/kv_einsum_1/w{suffix}",
        f"llm/layers/attn/q_einsum_1/w{suffix}",
        f"llm/layers/mlp_1/gating_einsum{suffix}",
        f"llm/layers/mlp_1/linear{suffix}",
        f"llm/layers/pre_attention_norm_1/scale{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/pre_ffw_norm_1/scale{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
    ]

    for key, value in state_dict.items():
        if key not in expert_keys:
            final_state_dict[key] = torch.from_numpy(value)
        else:
            expert_dict[key] = value

    return final_state_dict, expert_dict
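The `transpose(3, 2, 0, 1)` applied to the patch embedding above is the standard layout conversion between JAX/Flax and PyTorch conv kernels. A quick shape check (the 14x14/1152 dimensions are illustrative, not taken from the config):

```python
import numpy as np

# JAX/Flax stores conv kernels as [H, W, in_channels, out_channels] (HWIO),
# while PyTorch nn.Conv2d expects [out_channels, in_channels, H, W] (OIHW).
hwio = np.zeros((14, 14, 3, 1152))  # e.g. a 14x14 patch embedding kernel
oihw = hwio.transpose(3, 2, 0, 1)
print(oihw.shape)  # (1152, 3, 14, 14)
```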


def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
    """Convert Gemma JAX parameters to PyTorch format."""
    # Add missing attributes to config if they don't exist
    if not hasattr(config, "vocab_size"):
        config.vocab_size = 257152  # PALIGEMMA_VOCAB_SIZE
    if not hasattr(config, "hidden_size"):
        config.hidden_size = config.width
    if not hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = config.depth
    if not hasattr(config, "num_attention_heads"):
        config.num_attention_heads = config.num_heads

    suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""

    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")

    # Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
    if "pi05" in checkpoint_dir:
        # Pi05 with adaptive normalization
        llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_input_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
        llm_post_attention_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
    else:
        # Regular pi0 with standard RMSNorm
        llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
        llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")

    for i in range(config.num_hidden_layers):
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
            .transpose(1, 0)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )

        if "pi05" in checkpoint_dir:
            # Pi05 with adaptive normalization - use Dense layer parameters directly
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
                llm_input_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
                llm_post_attention_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
                llm_input_layernorm_kernel[i].transpose()
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
                llm_post_attention_layernorm_kernel[i].transpose()
            )
        else:
            # Regular pi0 with standard RMSNorm
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
                llm_input_layernorm[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
                llm_post_attention_layernorm[i]
            )

    # Handle final norm layer
    if "pi05" in checkpoint_dir:
        # Pi05 with adaptive normalization - use Dense layer parameters directly
        final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
        final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
    else:
        # Regular pi0 with standard RMSNorm
        state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
            f"llm/final_norm_{num_expert}/scale{suffix}"
        )

    # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector  # weights are tied.

    final_state_dict = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            final_state_dict[key] = torch.from_numpy(value)
        else:
            final_state_dict[key] = value

    return final_state_dict
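The `transpose(0, 2, 1)` + `reshape` pattern used for the fused `q_einsum` weights above can be verified in isolation. A sketch with illustrative dimensions (not taken from the actual Gemma config):

```python
import numpy as np

# The fused q_einsum weight is stored as [num_heads, hidden_size, head_dim];
# transpose(0, 2, 1) + reshape stacks the per-head projections into the
# [num_heads * head_dim, hidden_size] layout of a PyTorch nn.Linear weight.
num_heads, hidden, head_dim = 8, 2048, 256
w = np.random.randn(num_heads, hidden, head_dim)
q_proj = w.transpose(0, 2, 1).reshape(num_heads * head_dim, hidden)
print(q_proj.shape)  # (2048, 2048)

# Row block i corresponds to head i's projection.
assert np.allclose(q_proj[:head_dim], w[0].T)
```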


def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
    """Load and process params by restoring via the JAX model loader first.

    This respects dtype conversions that occur during model restore.
    """
    # Use the repository restore utility to load a pure dict of params (value suffix removed)
    params = openpi.models.model.restore_params(
        f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
    )

    return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}


def load_jax_model_and_print_keys(checkpoint_dir: str):
    """
    Load JAX model from checkpoint and print all parameter keys.

    Args:
        checkpoint_dir: Path to the checkpoint directory
    """
    checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
    # Initialize checkpointer
    checkpointer = ocp.PyTreeCheckpointer()
    metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
    print(utils.array_tree_to_info(metadata))
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(
|
||||
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
|
||||
):
|
||||
"""
|
||||
Convert PI0 JAX checkpoint to PyTorch format.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to the JAX checkpoint
|
||||
precision: Model precision (float32, bfloat16, float16)
|
||||
output_path: Path to save the converted PyTorch model
|
||||
model_config: Model config
|
||||
"""
|
||||
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
|
||||
print(f"Model config: {model_config}")
|
||||
|
||||
# Break down orbax ckpts by restoring via JAX to respect dtype
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
|
||||
|
||||
# Process projection params
|
||||
if model_config.pi05:
|
||||
keys = [
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"time_mlp_in",
|
||||
"time_mlp_out",
|
||||
]
|
||||
else:
|
||||
keys = [
|
||||
"state_proj",
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"action_time_mlp_in",
|
||||
"action_time_mlp_out",
|
||||
]
|
||||
|
||||
projection_params = {}
|
||||
for key in keys:
|
||||
kernel_params = initial_params["projection_params"][key]["kernel"]
|
||||
bias_params = initial_params["projection_params"][key]["bias"]
|
||||
if isinstance(kernel_params, dict):
|
||||
weight = kernel_params["value"]
|
||||
bias = bias_params["value"]
|
||||
else:
|
||||
weight = kernel_params
|
||||
bias = bias_params
|
||||
|
||||
pytorch_weight_key = f"{key}.weight"
|
||||
pytorch_bias_key = f"{key}.bias"
|
||||
|
||||
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
|
||||
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
|
||||
|
||||
# Create configs based on checkpoint path
|
||||
# All models use the same PaliGemma config structure
|
||||
class PaliGemmaConfig:
|
||||
def __init__(self):
|
||||
self.vision_config = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"hidden_size": 1152,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"intermediate_size": 4304,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
},
|
||||
)()
|
||||
self.text_config = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"hidden_size": 2048,
|
||||
"num_hidden_layers": 18,
|
||||
"num_attention_heads": 8,
|
||||
"head_dim": 256,
|
||||
"intermediate_size": 16384,
|
||||
},
|
||||
)()
|
||||
|
||||
paligemma_config = PaliGemmaConfig()
|
||||
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
|
||||
|
||||
# Process PaliGemma weights
|
||||
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
|
||||
|
||||
# Process Gemma weights from expert_params
|
||||
gemma_params = slice_gemma_state_dict(
|
||||
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
|
||||
)
|
||||
|
||||
# Instantiate model
|
||||
pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
|
||||
|
||||
# Combine all parameters (no prefix needed for our model structure)
|
||||
all_params = {**paligemma_params, **gemma_params, **projection_params}
|
||||
|
||||
# Load state dict
|
||||
pi0_model.load_state_dict(all_params, strict=False)
|
||||
|
||||
if precision == "float32":
|
||||
pi0_model = pi0_model.to(torch.float32)
|
||||
elif precision == "bfloat16":
|
||||
pi0_model = pi0_model.to(torch.bfloat16)
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Save the converted model using safetensors
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
# Save model weights as SafeTensors using save_model to handle tied weights
|
||||
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
|
||||
|
||||
# Copy assets folder if it exists
|
||||
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
|
||||
if assets_source.exists():
|
||||
assets_dest = pathlib.Path(output_path) / "assets"
|
||||
if assets_dest.exists():
|
||||
shutil.rmtree(assets_dest)
|
||||
shutil.copytree(assets_source, assets_dest)
|
||||
|
||||
# Save config as JSON for reference
|
||||
config_dict = {
|
||||
"action_dim": model_config.action_dim,
|
||||
"action_horizon": model_config.action_horizon,
|
||||
"paligemma_variant": model_config.paligemma_variant,
|
||||
"action_expert_variant": model_config.action_expert_variant,
|
||||
"precision": precision,
|
||||
}
|
||||
with open(os.path.join(output_path, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
print("Model conversion completed successfully!")
|
||||
print(f"Model saved to {output_path}")
|
||||
|
||||
|
||||
def main(
|
||||
checkpoint_dir: str,
|
||||
config_name: str,
|
||||
output_path: str | None = None,
|
||||
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
|
||||
*,
|
||||
inspect_only: bool = False,
|
||||
):
|
||||
"""Load JAX model and optionally convert to PyTorch.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to the JAX checkpoint directory
|
||||
output_path: Path to save converted PyTorch model (required for conversion)
|
||||
precision: Precision for model conversion
|
||||
inspect_only: Only inspect parameter keys, don't convert
|
||||
"""
|
||||
model_config = _config.get_config(config_name).model
|
||||
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
|
||||
raise ValueError(f"Config {config_name} is not a Pi0Config")
|
||||
if inspect_only:
|
||||
load_jax_model_and_print_keys(checkpoint_dir)
|
||||
else:
|
||||
if not output_path:
|
||||
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
|
||||
return
|
||||
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
84 policy/openpi-InternData-A1/examples/droid/README.md Normal file
@@ -0,0 +1,84 @@
# DROID Policies in openpi

We offer instructions for:
- [Running inference with our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
- [Running inference with other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-other-policies)
- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)

## Running DROID Inference

This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.

### Step 1: Start a policy server

Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.

1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
2. Start the openpi server via the following command:

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
```

You can also run the equivalent shorthand command:

```bash
uv run scripts/serve_policy.py --env=DROID
```

### Step 2: Run the DROID robot

1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
2. On the control laptop, activate your DROID conda environment.
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (it has very few dependencies and should be fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras. You can find the camera IDs by running `ZED_Explorer` in the command line, which opens a tool that shows all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with.
7. Run the `main.py` file. Make sure to point `--remote_host` and `--remote_port` at the policy server. (To verify that the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also specify which external camera to feed to the policy (we only input one external camera), choosing from `["left", "right"]`.

```bash
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
```
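For reference, `main.py` talks to the server through the lightweight `openpi-client` package. Below is a minimal sketch of how a single query is assembled; the observation key names, image sizes, and the `WebsocketClientPolicy` usage are assumptions modeled on the DROID RLDS conventions, so check `main.py` for the exact contract (the network call is left commented out since it needs a running server):

```python
import numpy as np


def build_observation(ext_img, wrist_img, joint_pos, gripper_pos, prompt):
    """Pack one control step into the flat dict the policy server expects (assumed key names)."""
    return {
        "observation/exterior_image_1_left": ext_img,
        "observation/wrist_image_left": wrist_img,
        "observation/joint_position": joint_pos,
        "observation/gripper_position": gripper_pos,
        "prompt": prompt,
    }


# Dummy frames stand in for real camera/robot readings here.
obs = build_observation(
    np.zeros((180, 320, 3), dtype=np.uint8),  # external camera frame
    np.zeros((180, 320, 3), dtype=np.uint8),  # wrist camera frame
    np.zeros(7, dtype=np.float32),  # joint positions
    np.zeros(1, dtype=np.float32),  # gripper position
    "pick up the cup",
)

# With a running policy server (hypothetical host/port):
# from openpi_client import websocket_client_policy
# policy = websocket_client_policy.WebsocketClientPolicy(host="<server_ip>", port=8000)
# action_chunk = policy.infer(obs)["actions"]  # one chunk of actions to execute
```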
The script will ask you to enter a free-form language instruction for the robot to follow. Point the cameras at the scene you want the robot to interact with; you _do not_ need to carefully control camera angles, object positions, etc. The policy is fairly robust in our experience. Happy prompting!

## Troubleshooting

| Issue | Solution |
|-------|----------|
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras helps. You can check all connected cameras by running `ZED_Explorer` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera plus the wrist camera, so make sure you are feeding the desired camera to the policy); `ZED_Explorer` is useful for checking this. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |

## Running Other Policies

We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies, then follow the instructions above to run evaluation on the DROID robot.

```bash
# Fine-tuned from pi0-FAST, using the FAST tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid

# Fine-tuned from pi0, using flow matching.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid

# Trained from PaliGemma, using an RT-2 / OpenVLA style binning tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid

# Trained from PaliGemma, using the universal FAST+ tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid

# Trained from PaliGemma, using a FAST tokenizer trained on the DROID dataset.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid

# Trained from PaliGemma, using an FSQ tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid

# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
```

You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
106 policy/openpi-InternData-A1/examples/droid/README_train.md Normal file
@@ -0,0 +1,106 @@
# Training on DROID

Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline (with small differences in data loading and the action space used). For a tutorial on fine-tuning with a smaller, custom dataset collected on the DROID platform, see below.

In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (at the moment, LeRobot isn't scalable enough for larger datasets like DROID -- though the team is working on improving it). Below, we provide instructions for updating your openpi environment for RLDS data loading and for downloading the DROID dataset.

## Install

We need a few additional dependencies for RLDS data loading. Run:
```bash
uv sync --group rlds
```

## Download DROID dataset

You can download the DROID dataset with the following command (after installing the `gsutil` Google Cloud CLI):
```bash
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
```

Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes), while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).

You will need 1.8TB of disk space to download the DROID RLDS dataset.

## Run

First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).

Then, compute normalization statistics (this will take ~10 minutes):
```bash
uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
```

Run training:
```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
```

**Note**: The original pi0.5-DROID model was trained with joint velocity actions. Joint velocity actions are not compatible with simulated evaluation environments (they are much harder to simulate), so we do not recommend training with joint velocity actions and instead use joint position actions here.

## Compute Requirements

Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, batch size 256, approx. 1 epoch). If you start from PaliGemma instead of the pi0 initialization, plan for ~5 days on 8x H100s (240k iterations, i.e., 3 epochs).

We have experimented with LoRA for cheaper fine-tuning, but haven't found the resulting policies to perform well so far.

## Data Filtering

Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean", and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface used during data collection; we will not go into detail here). Appropriate filtering of these idle transitions can improve policy performance.

By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter out any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter or create your own sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).

**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
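The pad-and-diff trick the filtering script uses to extract idle and non-idle runs can be illustrated on a toy boolean mask. This is a simplified sketch with the thresholds shrunk for readability, not the exact production logic (the real script also trims the end of each range):

```python
import numpy as np


def keep_ranges(is_idle, min_idle_len=2, min_keep_len=3):
    """Return (start, end) index ranges of frames to keep, dropping long idle runs."""

    def runs(mask):
        # Pad with False so every run of True values produces a +1 start and a -1 end in the diff.
        diff = np.diff(np.concatenate([[False], mask, [False]]).astype(int))
        return np.where(diff == 1)[0], np.where(diff == -1)[0]

    # Drop only idle runs that are long enough to matter.
    keep = np.ones(len(is_idle), dtype=bool)
    for s, e in zip(*runs(np.asarray(is_idle))):
        if e - s >= min_idle_len:
            keep[s:e] = False

    # Keep only sufficiently long non-idle runs.
    starts, ends = runs(keep)
    return [(int(s), int(e)) for s, e in zip(starts, ends) if e - s >= min_keep_len]


mask = [False, False, False, True, True, True, False, False, False, False, True, False]
print(keep_ranges(mask))  # → [(0, 3), (6, 12)]
```

Note that the single idle frame at index 10 survives because the run is shorter than `min_idle_len`, mirroring how the real filter tolerates short pauses inside a movement segment.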
## RoboArena

Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)

If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).


# Fine-Tuning on Custom DROID Datasets

Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. As with other datasets, we will first convert the custom DROID dataset to LeRobot format and then fine-tune a model (pi05-droid) on it.

Note: We use LeRobot here since we assume the custom DROID fine-tuning dataset is relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above).

## Step 1: Converting your custom DROID dataset to LeRobot

We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
```bash
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
```

We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
```bash
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
```

For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).

Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5 min for the 30 demonstrations):
```bash
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
```

## Step 2: Run fine-tuning with your custom dataset

Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created. You can easily modify the config to work with other base models or with your own custom DROID dataset in `config.py` (search for `pi05_droid_finetune`).

To launch training:
```bash
uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
```

Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
@@ -0,0 +1,103 @@
"""
Iterates through the DROID dataset and creates a JSON mapping from episode unique IDs to the ranges of time steps
that should be sampled during training (all others are filtered out).

Filtering logic:
We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
(default 7 -- since most DROID action-chunking policies execute the first 8 actions of each chunk, filtering
this way means the policy will not get stuck outputting stationary actions). Additionally, we only keep non-idle
ranges of length at least min_non_idle_len (default 16 frames = ~1 second), while also removing the last
filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).

This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
"""

import json
import os
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Set to the GPU you want to use, or leave empty for CPU

builder = tfds.builder_from_directory(
    # Path to the `droid` directory (not its parent)
    builder_dir="<path_to_droid_dataset_tfds_files>",
)
ds = builder.as_dataset(split="train", shuffle_files=False)
ds = ds.apply(tf.data.experimental.ignore_errors())  # ignore_errors returns a transformation; it must be applied

keep_ranges_path = "<path_to_where_to_save_the_json>"

min_idle_len = 7  # If more than this number of consecutive idle frames, filter all of them out
min_non_idle_len = 16  # If fewer than this number of consecutive non-idle frames, filter all of them out
filter_last_n_in_ranges = 10  # When using a filter dict, remove this many frames from the end of each range

keep_ranges_map = {}
if Path(keep_ranges_path).exists():
    with Path(keep_ranges_path).open("r") as f:
        keep_ranges_map = json.load(f)
    print(f"Resuming from {len(keep_ranges_map)} episodes already processed")

for ep_idx, ep in enumerate(tqdm(ds)):
    recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
    file_path = ep["episode_metadata"]["file_path"].numpy().decode()

    key = f"{recording_folderpath}--{file_path}"
    if key in keep_ranges_map:
        continue

    joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
    joint_velocities = np.array(joint_velocities)

    is_idle_array = np.hstack(
        [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
    )

    # Find what steps go from idle to non-idle and vice-versa
    is_idle_padded = np.concatenate(
        [[False], is_idle_array, [False]]
    )  # Pad with False so an idle run at the first or last step still registers a transition

    is_idle_diff = np.diff(is_idle_padded.astype(int))
    is_idle_true_starts = np.where(is_idle_diff == 1)[0]  # +1 transitions --> going from non-idle to idle
    is_idle_true_ends = np.where(is_idle_diff == -1)[0]  # -1 transitions --> going from idle to non-idle

    # Find which steps correspond to idle segments of length at least min_idle_len
    true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
    is_idle_true_starts = is_idle_true_starts[true_segment_masks]
    is_idle_true_ends = is_idle_true_ends[true_segment_masks]

    keep_mask = np.ones(len(joint_velocities), dtype=bool)
    for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
        keep_mask[start:end] = False

    # Get all non-idle ranges of length at least min_non_idle_len
    # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
    keep_padded = np.concatenate([[False], keep_mask, [False]])

    keep_diff = np.diff(keep_padded.astype(int))
    keep_true_starts = np.where(keep_diff == 1)[0]  # +1 transitions --> going from filter out to keep
    keep_true_ends = np.where(keep_diff == -1)[0]  # -1 transitions --> going from keep to filter out

    # Find which steps correspond to non-idle segments of length at least min_non_idle_len
    true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
    keep_true_starts = keep_true_starts[true_segment_masks]
    keep_true_ends = keep_true_ends[true_segment_masks]

    # Add mapping from episode unique ID key to list of non-idle ranges to keep
    keep_ranges_map[key] = []
    for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
        keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))

    if ep_idx % 1000 == 0:
        with Path(keep_ranges_path).open("w") as f:
            json.dump(keep_ranges_map, f)

print("Done!")
with Path(keep_ranges_path).open("w") as f:
    json.dump(keep_ranges_map, f)
@@ -0,0 +1,477 @@
"""
|
||||
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
|
||||
|
||||
Usage:
|
||||
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
|
||||
|
||||
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
|
||||
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
|
||||
|
||||
The resulting dataset will get saved to the $LEROBOT_HOME directory.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import tyro
|
||||
|
||||
REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
|
||||
|
||||
|
||||
def resize_image(image, size):
|
||||
image = Image.fromarray(image)
|
||||
return np.array(image.resize(size, resample=Image.BICUBIC))
|
||||
|
||||
|
||||
def main(data_dir: str, *, push_to_hub: bool = False):
|
||||
# Clean up any existing dataset in the output directory
|
||||
output_path = HF_LEROBOT_HOME / REPO_NAME
|
||||
if output_path.exists():
|
||||
shutil.rmtree(output_path)
|
||||
data_dir = Path(data_dir)
|
||||
|
||||
# Create LeRobot dataset, define features to store
|
||||
# We will follow the DROID data naming conventions here.
|
||||
# LeRobot assumes that dtype of image data is `image`
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=REPO_NAME,
|
||||
robot_type="panda",
|
||||
fps=15, # DROID data is typically recorded at 15fps
|
||||
features={
|
||||
# We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
|
||||
"exterior_image_1_left": {
|
||||
"dtype": "image",
|
||||
"shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"exterior_image_2_left": {
|
||||
"dtype": "image",
|
||||
"shape": (180, 320, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"wrist_image_left": {
|
||||
"dtype": "image",
|
||||
"shape": (180, 320, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"joint_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": ["joint_position"],
|
||||
},
|
            "gripper_position": {
                "dtype": "float32",
                "shape": (1,),
                "names": ["gripper_position"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),  # We will use joint *velocity* actions here (7D) + gripper position (1D)
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Load language annotations
    # Note: we load the DROID language annotations for this example, but you can manually define them for your own data
    with (data_dir / "aggregated-annotations-030724.json").open() as f:
        language_annotations = json.load(f)

    # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
    # We assume the following directory structure:
    # RAW_DROID_PATH/
    #   - <...>/
    #     - recordings/
    #       - MP4/
    #         - <camera_id>.mp4  # single-view video of left stereo pair camera
    #     - trajectory.h5
    #   - <...>/
    episode_paths = list(data_dir.glob("**/trajectory.h5"))
    print(f"Found {len(episode_paths)} episodes for conversion")

    # Loop over all episodes and write them to the LeRobot dataset
    for episode_path in tqdm(episode_paths, desc="Converting episodes"):
        # Load raw data
        recording_folderpath = episode_path.parent / "recordings" / "MP4"
        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))

        # To load the language instruction, we need to parse out the episode_id from the metadata file
        # Again, you can modify this step for your own data, to load your own language instructions
        metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
        episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
        language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
            "language_instruction1"
        ]
        print(f"Converting episode with language instruction: {language_instruction}")

        # Write to LeRobot dataset
        for step in trajectory:
            camera_type_dict = step["observation"]["camera_type"]
            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
            dataset.add_frame(
                {
                    # Note: need to flip BGR --> RGB for loaded images
                    "exterior_image_1_left": resize_image(
                        step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
                    ),
                    "exterior_image_2_left": resize_image(
                        step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
                    ),
                    "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
                    "joint_position": np.asarray(
                        step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
                    ),
                    "gripper_position": np.asarray(
                        step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
                    ),
                    # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
                    "actions": np.concatenate(
                        [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
                    ),
                    "task": language_instruction,
                }
            )
        dataset.save_episode()

    # Optionally push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            tags=["libero", "panda", "rlds"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )

##########################################################################################################
################ The rest of this file are functions to parse the raw DROID data #########################
################ You don't need to worry about understanding this part ###################################
################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
##########################################################################################################


camera_type_dict = {
    "hand_camera_id": 0,
    "varied_camera_1_id": 1,
    "varied_camera_2_id": 1,
}

camera_type_to_string_dict = {
    0: "hand_camera",
    1: "varied_camera",
    2: "fixed_camera",
}


def get_camera_type(cam_id):
    if cam_id not in camera_type_dict:
        return None
    type_int = camera_type_dict[cam_id]
    return camera_type_to_string_dict[type_int]

class MP4Reader:
    def __init__(self, filepath, serial_number):
        # Save Parameters #
        self.serial_number = serial_number
        self._index = 0

        # Open Video Reader #
        self._mp4_reader = cv2.VideoCapture(filepath)
        if not self._mp4_reader.isOpened():
            raise RuntimeError("Corrupted MP4 File")

    def set_reading_parameters(
        self,
        image=True,  # noqa: FBT002
        concatenate_images=False,  # noqa: FBT002
        resolution=(0, 0),
        resize_func=None,
    ):
        # Save Parameters #
        self.image = image
        self.concatenate_images = concatenate_images
        self.resolution = resolution
        self.resize_func = resize_func if resize_func is not None else cv2.resize
        self.skip_reading = not image

    def get_frame_resolution(self):
        width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
        return (width, height)

    def get_frame_count(self):
        if self.skip_reading:
            return 0
        return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))

    def set_frame_index(self, index):
        if self.skip_reading:
            return

        if index < self._index:
            self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
            self._index = index

        while self._index < index:
            self.read_camera(ignore_data=True)

    def _process_frame(self, frame):
        frame = copy.deepcopy(frame)
        if self.resolution == (0, 0):
            return frame
        return self.resize_func(frame, self.resolution)

    def read_camera(self, ignore_data=False, correct_timestamp=None):  # noqa: FBT002
        # Skip If Read Unnecessary #
        if self.skip_reading:
            return {}

        # Read Camera #
        success, frame = self._mp4_reader.read()

        self._index += 1
        if not success:
            return None
        if ignore_data:
            return None

        # Return Data #
        data_dict = {}

        if self.concatenate_images or "stereo" not in self.serial_number:
            data_dict["image"] = {self.serial_number: self._process_frame(frame)}
        else:
            single_width = frame.shape[1] // 2
            data_dict["image"] = {
                self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
                self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
            }

        return data_dict

    def disable_camera(self):
        if hasattr(self, "_mp4_reader"):
            self._mp4_reader.release()

class RecordedMultiCameraWrapper:
    def __init__(self, recording_folderpath, camera_kwargs={}):  # noqa: B006
        # Save Camera Info #
        self.camera_kwargs = camera_kwargs

        # Open Camera Readers #
        mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
        all_filepaths = mp4_filepaths

        self.camera_dict = {}
        for f in all_filepaths:
            serial_number = f.split("/")[-1][:-4]

            if f.endswith(".mp4"):
                Reader = MP4Reader  # noqa: N806
            else:
                raise ValueError

            self.camera_dict[serial_number] = Reader(f, serial_number)

    def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):  # noqa: B006
        full_obs_dict = defaultdict(dict)

        # Read Cameras In Randomized Order #
        all_cam_ids = list(self.camera_dict.keys())
        # random.shuffle(all_cam_ids)

        for cam_id in all_cam_ids:
            if "stereo" in cam_id:
                continue
            try:
                cam_type = camera_type_dict[cam_id]
            except KeyError:
                print(f"{self.camera_dict} -- {camera_type_dict}")
                raise ValueError(f"Camera type {cam_id} not found in camera_type_dict")  # noqa: B904
            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
            self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)

            timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
            if index is not None:
                self.camera_dict[cam_id].set_frame_index(index)

            data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)

            # Process Returned Data #
            if data_dict is None:
                return None
            for key in data_dict:
                full_obs_dict[key].update(data_dict[key])

        return full_obs_dict

def get_hdf5_length(hdf5_file, keys_to_ignore=[]):  # noqa: B006
    length = None

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            curr_length = len(curr_data)
        else:
            raise ValueError

        if length is None:
            length = curr_length
        assert curr_length == length

    return length


def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):  # noqa: B006
    data_dict = {}

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            data_dict[key] = curr_data[index]
        else:
            raise ValueError

    return data_dict

class TrajectoryReader:
    def __init__(self, filepath, read_images=True):  # noqa: FBT002
        self._hdf5_file = h5py.File(filepath, "r")
        is_video_folder = "observations/videos" in self._hdf5_file
        self._read_images = read_images and is_video_folder
        self._length = get_hdf5_length(self._hdf5_file)
        self._video_readers = {}
        self._index = 0

    def length(self):
        return self._length

    def read_timestep(self, index=None, keys_to_ignore=[]):  # noqa: B006
        # Make Sure We Read Within Range #
        if index is None:
            index = self._index
        else:
            assert not self._read_images
            self._index = index
        assert index < self._length

        # Load Low Dimensional Data #
        keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)

        # Increment Read Index #
        self._index += 1

        # Return Timestep #
        return timestep

    def close(self):
        self._hdf5_file.close()


def load_trajectory(
    filepath=None,
    read_cameras=True,  # noqa: FBT002
    recording_folderpath=None,
    camera_kwargs={},  # noqa: B006
    remove_skipped_steps=False,  # noqa: FBT002
    num_samples_per_traj=None,
    num_samples_per_traj_coeff=1.5,
):
    read_recording_folderpath = read_cameras and (recording_folderpath is not None)

    traj_reader = TrajectoryReader(filepath)
    if read_recording_folderpath:
        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)

    horizon = traj_reader.length()
    timestep_list = []

    # Choose Timesteps To Save #
    if num_samples_per_traj:
        num_to_save = num_samples_per_traj
        if remove_skipped_steps:
            num_to_save = int(num_to_save * num_samples_per_traj_coeff)
        max_size = min(num_to_save, horizon)
        indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
    else:
        indices_to_save = np.arange(horizon)

    # Iterate Over Trajectory #
    for i in indices_to_save:
        # Get HDF5 Data #
        timestep = traj_reader.read_timestep(index=i)

        # If Applicable, Get Recorded Data #
        if read_recording_folderpath:
            timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
            camera_type_dict = {
                k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
            }
            camera_obs = camera_reader.read_cameras(
                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
            )
            camera_failed = camera_obs is None

            # Add Data To Timestep If Successful #
            if camera_failed:
                break
            timestep["observation"].update(camera_obs)

        # Filter Steps #
        step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
        delete_skipped_step = step_skipped and remove_skipped_steps

        # Save Filtered Timesteps #
        if delete_skipped_step:
            del timestep
        else:
            timestep_list.append(timestep)

    # Remove Extra Transitions #
    timestep_list = np.array(timestep_list)
    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
        ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
        timestep_list = timestep_list[ind_to_keep]

    # Close Readers #
    traj_reader.close()

    # Return Data #
    return timestep_list


if __name__ == "__main__":
    tyro.cli(main)
246  policy/openpi-InternData-A1/examples/droid/main.py  Normal file
@@ -0,0 +1,246 @@
# ruff: noqa

import contextlib
import dataclasses
import datetime
import faulthandler
import os
import signal
import time

from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from droid.robot_env import RobotEnv
import tqdm
import tyro

faulthandler.enable()

# DROID data collection frequency -- we slow down execution to match this frequency
DROID_CONTROL_FREQUENCY = 15


@dataclasses.dataclass
class Args:
    # Hardware parameters
    left_camera_id: str = "<your_camera_id>"  # e.g., "24259877"
    right_camera_id: str = "<your_camera_id>"  # e.g., "24514023"
    wrist_camera_id: str = "<your_camera_id>"  # e.g., "13062452"

    # Policy parameters
    external_camera: str | None = (
        None  # which external camera should be fed to the policy, choose from ["left", "right"]
    )

    # Rollout parameters
    max_timesteps: int = 600
    # How many actions to execute from a predicted action chunk before querying the policy server again.
    # 8 is usually a good default (equals 0.5 seconds of action execution).
    open_loop_horizon: int = 8

    # Remote server parameters
    remote_host: str = "0.0.0.0"  # point this to the IP address of the policy server, e.g., "192.168.1.100"
    remote_port: int = (
        8000  # point this to the port of the policy server, default server port for openpi servers is 8000
    )


# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
# waiting for a new action chunk, it will raise an exception and the server connection dies.
# This context manager temporarily prevents Ctrl+C and delays it until after the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
    interrupted = False
    original_handler = signal.getsignal(signal.SIGINT)

    def handler(signum, frame):
        nonlocal interrupted
        interrupted = True

    signal.signal(signal.SIGINT, handler)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, original_handler)
        if interrupted:
            raise KeyboardInterrupt

def main(args: Args):
    # Make sure external camera is specified by user -- we only use one external camera for the policy
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"

    # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    # Connect to the policy server
    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    df = pd.DataFrame(columns=["success", "duration", "video_filename"])

    while True:
        instruction = input("Enter instruction: ")

        # Rollout parameters
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        # Prepare to save video of rollout
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            start_time = time.time()
            try:
                # Get the current observation
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    # Save the first observation to disk
                    save_to_disk=t_step == 0,
                )

                video.append(curr_obs[f"{args.external_camera}_image"])

                # Send websocket request to policy server if it's time to predict a new chunk
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    # We resize images on the robot laptop to minimize the amount of data sent to the policy server
                    # and improve latency.
                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(
                            curr_obs[f"{args.external_camera}_image"], 224, 224
                        ),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it.
                    # Ctrl+C will be handled after the server call is complete.
                    with prevent_keyboard_interrupt():
                        # This returns an action chunk of shape [10, 8]: 10 steps of joint velocity (7D) + gripper position (1D)
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                    assert pred_action_chunk.shape == (10, 8)

                # Select current action to execute from chunk
                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize gripper action
                if action[-1].item() > 0.5:
                    # action[-1] = 1.0
                    action = np.concatenate([action[:-1], np.ones((1,))])
                else:
                    # action[-1] = 0.0
                    action = np.concatenate([action[:-1], np.zeros((1,))])

                # Clip all dimensions of action to [-1, 1]
                action = np.clip(action, -1, 1)

                env.step(action)

                # Sleep to match DROID data collection frequency
                elapsed_time = time.time() - start_time
                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
            except KeyboardInterrupt:
                break

        video = np.stack(video)
        save_filename = "video_" + timestamp
        ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")

        success: str | float | None = None
        while not isinstance(success, float):
            success = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
            )
            if success == "y":
                success = 1.0
            elif success == "n":
                success = 0.0
            else:
                # Interpret the input as a percentage in [0, 100]; re-prompt on invalid input.
                try:
                    success = float(success) / 100
                except ValueError:
                    continue
                if not (0 <= success <= 1):
                    print(f"Success must be a number in [0, 100] but got: {success * 100}")
                    success = None

        # Note: pd.DataFrame.append was removed in pandas 2.0, so we append the row with pd.concat.
        df = pd.concat(
            [
                df,
                pd.DataFrame(
                    [
                        {
                            "success": success,
                            "duration": t_step,
                            "video_filename": save_filename,
                        }
                    ]
                ),
            ],
            ignore_index=True,
        )

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    df.to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")

def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
    image_observations = obs_dict["image"]
    left_image, right_image, wrist_image = None, None, None
    for key in image_observations:
        # Note the "left" below refers to the left camera in the stereo pair.
        # The model is only trained on left stereo cams, so we only feed those.
        if args.left_camera_id in key and "left" in key:
            left_image = image_observations[key]
        elif args.right_camera_id in key and "left" in key:
            right_image = image_observations[key]
        elif args.wrist_camera_id in key and "left" in key:
            wrist_image = image_observations[key]

    # Drop the alpha dimension
    left_image = left_image[..., :3]
    right_image = right_image[..., :3]
    wrist_image = wrist_image[..., :3]

    # Convert to RGB
    left_image = left_image[..., ::-1]
    right_image = right_image[..., ::-1]
    wrist_image = wrist_image[..., ::-1]

    # In addition to image observations, also capture the proprioceptive state
    robot_state = obs_dict["robot_state"]
    cartesian_position = np.array(robot_state["cartesian_position"])
    joint_position = np.array(robot_state["joint_positions"])
    gripper_position = np.array([robot_state["gripper_position"]])

    # Save the images to disk so that they can be viewed live while the robot is running.
    # Create one combined image to make live viewing easy.
    if save_to_disk:
        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
        combined_image = Image.fromarray(combined_image)
        combined_image.save("robot_camera_views.png")

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }


if __name__ == "__main__":
    args: Args = tyro.cli(Args)
    main(args)
137  policy/openpi-InternData-A1/examples/inference.ipynb  Normal file
@@ -0,0 +1,137 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "\n",
    "import jax\n",
    "\n",
    "from openpi.models import model as _model\n",
    "from openpi.policies import droid_policy\n",
    "from openpi.policies import policy_config as _policy_config\n",
    "from openpi.shared import download\n",
    "from openpi.training import config as _config\n",
    "from openpi.training import data_loader as _data_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy inference\n",
    "\n",
    "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_fast_droid\")\n",
    "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
    "\n",
    "# Create a trained policy.\n",
    "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
    "\n",
    "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
    "example = droid_policy.make_droid_example()\n",
    "result = policy.infer(example)\n",
    "\n",
    "# Delete the policy to free up memory.\n",
    "del policy\n",
    "\n",
    "print(\"Actions shape:\", result[\"actions\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Working with a live model\n",
    "\n",
    "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_aloha_sim\")\n",
    "\n",
    "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
    "key = jax.random.key(0)\n",
    "\n",
    "# Create a model from the checkpoint.\n",
    "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
    "\n",
    "# We can create fake observations and actions to test the model.\n",
    "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
    "\n",
    "# Compute the training loss on the fake batch.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the batch size to reduce memory usage.\n",
    "config = dataclasses.replace(config, batch_size=2)\n",
    "\n",
    "# Load a single batch of data. This is the same data that will be used during training.\n",
    "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
    "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
    "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
    "obs, act = next(iter(loader))\n",
    "\n",
    "# Compute the training loss on the real batch.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "\n",
    "# Delete the model to free up memory.\n",
    "del model\n",
    "\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
59  policy/openpi-InternData-A1/examples/libero/Dockerfile  Normal file
@@ -0,0 +1,59 @@
# Dockerfile for the LIBERO benchmark.

# Build the container:
# docker build . -t libero -f examples/libero/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash

FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

RUN apt-get update && \
    apt-get install -y \
    make \
    g++ \
    clang \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero

# Create a default config file to avoid an input prompt from LIBERO's init script.
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
ENV LIBERO_CONFIG_PATH=/tmp/libero
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
benchmark_root: /app/third_party/libero/libero/libero
bddl_files: /app/third_party/libero/libero/libero/bddl_files
init_states: /app/third_party/libero/libero/libero/init_files
datasets: /app/third_party/libero/libero/datasets
assets: /app/third_party/libero/libero/libero/assets
EOF

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]
71  policy/openpi-InternData-A1/examples/libero/README.md  Normal file
@@ -0,0 +1,71 @@
# LIBERO Benchmark

This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO

Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
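For reference, a recompile invocation with that flag might look like the following sketch. The input file name (`requirements.in`) is an assumption; substitute whatever source file this directory actually compiles from.

```shell
# Hypothetical recompile of the LIBERO example requirements.
# The input file name (requirements.in) is an assumption -- adjust as needed.
uv pip compile examples/libero/requirements.in \
    -o examples/libero/requirements.txt \
    --extra-index-url https://download.pytorch.org/whl/cu113 \
    --index-strategy=unsafe-best-match
```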

This example requires git submodules to be initialized. Don't forget to run:

```bash
git submodule update --init --recursive
```

## With Docker (recommended)

```bash
# Grant access to the X11 server:
sudo xhost +local:docker

# To run with the default checkpoint and task suite:
SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build

# To run with glx for Mujoco instead (use this if you have egl errors):
MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
```

You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
For example:

```bash
# To load a custom checkpoint (located in the top-level openpi/ directory):
export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"

# To run the libero_10 task suite:
export CLIENT_ARGS="--args.task-suite-name libero_10"
```

## Without Docker (not recommended)

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.8 examples/libero/.venv
source examples/libero/.venv/bin/activate
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
uv pip install -e packages/openpi-client
uv pip install -e third_party/libero
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero

# Run the simulation
python examples/libero/main.py

# To run with glx for Mujoco instead (use this if you have egl errors):
MUJOCO_GL=glx python examples/libero/main.py
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env LIBERO
```

## Results

If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
checkpoint was trained in openpi with the `pi05_libero` config.

| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|-------|----------------|---------------|-------------|-----------|---------|
| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 |
||||
54
policy/openpi-InternData-A1/examples/libero/compose.yml
Normal file
@@ -0,0 +1,54 @@
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
  runtime:
    image: libero
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/libero/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data
      - /tmp/.X11-unix:/tmp/.X11-unix:ro
    environment:
      - CLIENT_ARGS
      - DISPLAY=$DISPLAY
      - MUJOCO_GL=${MUJOCO_GL:-egl}
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
@@ -0,0 +1,104 @@
"""
Minimal example script for converting a dataset to LeRobot format.

We use the Libero dataset (stored in RLDS) for this example, but it can be easily
modified for any other data you have saved in a custom format.

Usage:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data

If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub

Note: to run the script, you need to install tensorflow_datasets:
`uv pip install tensorflow tensorflow_datasets`

You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
Running this conversion script will take approximately 30 minutes.
"""

import shutil

from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds
import tyro

REPO_NAME = "your_hf_username/libero"  # Name of the output dataset, also used for the Hugging Face Hub
RAW_DATASET_NAMES = [
    "libero_10_no_noops",
    "libero_goal_no_noops",
    "libero_object_no_noops",
    "libero_spatial_no_noops",
]  # For simplicity we will combine multiple Libero datasets into one training dataset


def main(data_dir: str, *, push_to_hub: bool = False):
    # Clean up any existing dataset in the output directory
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)

    # Create LeRobot dataset, define features to store
    # OpenPi assumes that proprio is stored in `state` and actions in `action`
    # LeRobot assumes that dtype of image data is `image`
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=10,
        features={
            "image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "state": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["state"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Loop over raw Libero datasets and write episodes to the LeRobot dataset
    # You can modify this for your own data format
    for raw_dataset_name in RAW_DATASET_NAMES:
        raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
        for episode in raw_dataset:
            for step in episode["steps"].as_numpy_iterator():
                dataset.add_frame(
                    {
                        "image": step["observation"]["image"],
                        "wrist_image": step["observation"]["wrist_image"],
                        "state": step["observation"]["state"],
                        "actions": step["action"],
                        "task": step["language_instruction"].decode(),
                    }
                )
            dataset.save_episode()

    # Optionally push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            tags=["libero", "panda", "rlds"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


if __name__ == "__main__":
    tyro.cli(main)
219
policy/openpi-InternData-A1/examples/libero/main.py
Normal file
@@ -0,0 +1,219 @@
import collections
import dataclasses
import logging
import math
import pathlib

import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro

LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
LIBERO_ENV_RESOLUTION = 256  # resolution used to render training data


@dataclasses.dataclass
class Args:
    #################################################################################################################
    # Model server parameters
    #################################################################################################################
    host: str = "0.0.0.0"
    port: int = 8000
    resize_size: int = 224
    replan_steps: int = 5

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = (
        "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    )
    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50  # Number of rollouts per task

    #################################################################################################################
    # Utils
    #################################################################################################################
    video_out_path: str = "data/libero/videos"  # Path to save videos

    seed: int = 7  # Random Seed (for reproducibility)


def eval_libero(args: Args) -> None:
    # Set random seed
    np.random.seed(args.seed)

    # Initialize LIBERO task suite
    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    num_tasks_in_suite = task_suite.n_tasks
    logging.info(f"Task suite: {args.task_suite_name}")

    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)

    if args.task_suite_name == "libero_spatial":
        max_steps = 220  # longest training demo has 193 steps
    elif args.task_suite_name == "libero_object":
        max_steps = 280  # longest training demo has 254 steps
    elif args.task_suite_name == "libero_goal":
        max_steps = 300  # longest training demo has 270 steps
    elif args.task_suite_name == "libero_10":
        max_steps = 520  # longest training demo has 505 steps
    elif args.task_suite_name == "libero_90":
        max_steps = 400  # longest training demo has 373 steps
    else:
        raise ValueError(f"Unknown task suite: {args.task_suite_name}")

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

    # Start evaluation
    total_episodes, total_successes = 0, 0
    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
        # Get task
        task = task_suite.get_task(task_id)

        # Get default LIBERO initial states
        initial_states = task_suite.get_task_init_states(task_id)

        # Initialize LIBERO environment and task description
        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

        # Start episodes
        task_episodes, task_successes = 0, 0
        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
            logging.info(f"\nTask: {task_description}")

            # Reset environment
            env.reset()
            action_plan = collections.deque()

            # Set initial states
            obs = env.set_init_state(initial_states[episode_idx])

            # Setup
            t = 0
            replay_images = []

            logging.info(f"Starting episode {task_episodes+1}...")
            while t < max_steps + args.num_steps_wait:
                try:
                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
                    # and we need to wait for them to fall
                    if t < args.num_steps_wait:
                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
                        t += 1
                        continue

                    # Get preprocessed image
                    # IMPORTANT: rotate 180 degrees to match train preprocessing
                    img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
                    wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
                    img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
                    )
                    wrist_img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
                    )

                    # Save preprocessed image for replay video
                    replay_images.append(img)

                    if not action_plan:
                        # Finished executing previous action chunk -- compute new chunk
                        # Prepare observations dict
                        element = {
                            "observation/image": img,
                            "observation/wrist_image": wrist_img,
                            "observation/state": np.concatenate(
                                (
                                    obs["robot0_eef_pos"],
                                    _quat2axisangle(obs["robot0_eef_quat"]),
                                    obs["robot0_gripper_qpos"],
                                )
                            ),
                            "prompt": str(task_description),
                        }

                        # Query model to get action
                        action_chunk = client.infer(element)["actions"]
                        assert (
                            len(action_chunk) >= args.replan_steps
                        ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                        action_plan.extend(action_chunk[: args.replan_steps])

                    action = action_plan.popleft()

                    # Execute action in environment
                    obs, reward, done, info = env.step(action.tolist())
                    if done:
                        task_successes += 1
                        total_successes += 1
                        break
                    t += 1

                except Exception as e:
                    logging.error(f"Caught exception: {e}")
                    break

            task_episodes += 1
            total_episodes += 1

            # Save a replay video of the episode
            suffix = "success" if done else "failure"
            task_segment = task_description.replace(" ", "_")
            imageio.mimwrite(
                pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
                [np.asarray(x) for x in replay_images],
                fps=10,
            )

            # Log current results
            logging.info(f"Success: {done}")
            logging.info(f"# episodes completed so far: {total_episodes}")
            logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")

        # Log final results
        logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
        logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")

    logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
    logging.info(f"Total episodes: {total_episodes}")
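The chunked-execution pattern in the evaluation loop above (query the policy for an action chunk, execute the first `replan_steps` actions, then re-query) can be sketched in isolation. `run_chunked_policy` and `dummy_policy` are illustrative names for this sketch, not part of this file:

```python
import collections

import numpy as np


def run_chunked_policy(infer, num_steps: int, replan_steps: int):
    """Receding-horizon execution: re-query the policy whenever the plan runs dry."""
    action_plan = collections.deque()
    executed = []
    for _ in range(num_steps):
        if not action_plan:
            chunk = infer()  # policy returns a chunk of future actions
            assert len(chunk) >= replan_steps, "policy must predict at least replan_steps actions"
            action_plan.extend(chunk[:replan_steps])
        executed.append(action_plan.popleft())
    return executed


# Dummy policy predicting 10-step chunks of 7-dim actions; with replan_steps=5
# and 20 environment steps it is queried exactly 4 times (at steps 0, 5, 10, 15).
calls = []


def dummy_policy():
    calls.append(1)
    return np.zeros((10, 7))


actions = run_chunked_policy(dummy_policy, num_steps=20, replan_steps=5)
```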


def _get_libero_env(task, resolution, seed):
    """Initializes and returns the LIBERO environment, along with the task description."""
    task_description = task.language
    task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
    env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
    env = OffScreenRenderEnv(**env_args)
    env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    return env, task_description


def _quat2axisangle(quat):
    """
    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
    """
    # clip quaternion
    if quat[3] > 1.0:
        quat[3] = 1.0
    elif quat[3] < -1.0:
        quat[3] = -1.0

    den = np.sqrt(1.0 - quat[3] * quat[3])
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(quat[3])) / den


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tyro.cli(eval_libero)
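The `_quat2axisangle` helper in `main.py` is easy to sanity-check in isolation. A standalone sketch reimplementing the same math (the function name here is illustrative): a 90-degree rotation about z, `quat = (0, 0, sin(pi/4), cos(pi/4))`, should map to the axis-angle vector `(0, 0, pi/2)`.

```python
import math

import numpy as np


def quat2axisangle(quat):
    """(x, y, z, w) quaternion -> axis-angle vector; same math as _quat2axisangle."""
    w = min(max(float(quat[3]), -1.0), 1.0)  # clip the scalar part
    den = math.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        return np.zeros(3)  # (near-)identity rotation
    return quat[:3] * 2.0 * math.acos(w) / den


# 90-degree rotation about z.
aa = quat2axisangle(np.array([0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)]))
```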
11
policy/openpi-InternData-A1/examples/libero/requirements.in
Normal file
@@ -0,0 +1,11 @@
imageio[ffmpeg]
numpy==1.22.4
tqdm
tyro
PyYaml
opencv-python==4.6.0.66
torch==1.11.0+cu113
torchvision==0.12.0+cu113
torchaudio==0.11.0+cu113
robosuite==1.4.1
matplotlib==3.5.3
136
policy/openpi-InternData-A1/examples/libero/requirements.txt
Normal file
@@ -0,0 +1,136 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
absl-py==2.1.0
    # via mujoco
certifi==2024.12.14
    # via requests
charset-normalizer==3.4.0
    # via requests
cycler==0.12.1
    # via matplotlib
docstring-parser==0.16
    # via tyro
etils==1.3.0
    # via mujoco
eval-type-backport==0.2.0
    # via tyro
evdev==1.7.1
    # via pynput
fonttools==4.55.3
    # via matplotlib
glfw==1.12.0
    # via mujoco
idna==3.10
    # via requests
imageio==2.35.1
    # via -r examples/libero/requirements.in
imageio-ffmpeg==0.5.1
    # via imageio
importlib-metadata==8.5.0
    # via typeguard
importlib-resources==6.4.5
    # via etils
kiwisolver==1.4.7
    # via matplotlib
llvmlite==0.36.0
    # via numba
markdown-it-py==3.0.0
    # via rich
matplotlib==3.5.3
    # via -r examples/libero/requirements.in
mdurl==0.1.2
    # via markdown-it-py
mujoco==3.2.3
    # via robosuite
numba==0.53.1
    # via robosuite
numpy==1.22.4
    # via
    #   -r examples/libero/requirements.in
    #   imageio
    #   matplotlib
    #   mujoco
    #   numba
    #   opencv-python
    #   robosuite
    #   scipy
    #   torchvision
opencv-python==4.6.0.66
    # via
    #   -r examples/libero/requirements.in
    #   robosuite
packaging==24.2
    # via matplotlib
pillow==10.4.0
    # via
    #   imageio
    #   matplotlib
    #   robosuite
    #   torchvision
psutil==6.1.0
    # via imageio
pygments==2.18.0
    # via rich
pynput==1.7.7
    # via robosuite
pyopengl==3.1.7
    # via mujoco
pyparsing==3.1.4
    # via matplotlib
python-dateutil==2.9.0.post0
    # via matplotlib
python-xlib==0.33
    # via pynput
pyyaml==6.0.2
    # via -r examples/libero/requirements.in
requests==2.32.3
    # via torchvision
rich==13.9.4
    # via tyro
robosuite==1.4.1
    # via -r examples/libero/requirements.in
scipy==1.10.1
    # via robosuite
setuptools==75.3.0
    # via
    #   imageio-ffmpeg
    #   numba
shtab==1.7.1
    # via tyro
six==1.17.0
    # via
    #   pynput
    #   python-dateutil
    #   python-xlib
termcolor==2.4.0
    # via robosuite
torch==1.11.0+cu113
    # via
    #   -r examples/libero/requirements.in
    #   torchaudio
    #   torchvision
torchaudio==0.11.0+cu113
    # via -r examples/libero/requirements.in
torchvision==0.12.0+cu113
    # via -r examples/libero/requirements.in
tqdm==4.67.1
    # via -r examples/libero/requirements.in
typeguard==4.4.0
    # via tyro
typing-extensions==4.12.2
    # via
    #   etils
    #   rich
    #   torch
    #   torchvision
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/libero/requirements.in
urllib3==2.2.3
    # via requests
zipp==3.20.2
    # via
    #   etils
    #   importlib-metadata
    #   importlib-resources
134
policy/openpi-InternData-A1/examples/policy_records.ipynb
Normal file
@@ -0,0 +1,134 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "record_path = pathlib.Path(\"../policy_records\")\n",
    "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
    "\n",
    "records = []\n",
    "for i in range(num_steps):\n",
    "    record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
    "    records.append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"length of records\", len(records))\n",
    "print(\"keys in records\", records[0].keys())\n",
    "\n",
    "for k in records[0]:\n",
    "    print(f\"{k} shape: {records[0][k].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "\n",
    "def get_image(step: int, idx: int = 0):\n",
    "    img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
    "    return img[idx].transpose(1, 2, 0)\n",
    "\n",
    "\n",
    "def show_image(step: int, idx_lst: list[int]):\n",
    "    imgs = [get_image(step, idx) for idx in idx_lst]\n",
    "    return Image.fromarray(np.hstack(imgs))\n",
    "\n",
    "\n",
    "for i in range(2):\n",
    "    display(show_image(i, [0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def get_axis(name, axis):\n",
    "    return np.array([record[name][axis] for record in records])\n",
    "\n",
    "\n",
    "# qpos is [..., 14] of type float:\n",
    "# 0-5: left arm joint angles\n",
    "# 6: left arm gripper\n",
    "# 7-12: right arm joint angles\n",
    "# 13: right arm gripper\n",
    "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
    "\n",
    "\n",
    "def make_data():\n",
    "    cur_dim = 0\n",
    "    in_data = {}\n",
    "    out_data = {}\n",
    "    for name, dim_size in names:\n",
    "        for i in range(dim_size):\n",
    "            in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
    "            out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
    "            cur_dim += 1\n",
    "    return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
    "\n",
    "\n",
    "in_data, out_data = make_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in in_data.columns:\n",
    "    data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
    "    data.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -0,0 +1,32 @@
# Dockerfile for the simple client.

# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash

FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
30
policy/openpi-InternData-A1/examples/simple_client/README.md
Normal file
@@ -0,0 +1,30 @@
# Simple Client

A minimal client that sends observations to the server and prints the inference rate.

You can specify which runtime environment to use with the `--env` flag. You can see the available options by running:

```bash
uv run examples/simple_client/main.py --help
```

## With Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/simple_client/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
uv run examples/simple_client/main.py --env DROID
```

Terminal window 2:

```bash
uv run scripts/serve_policy.py --env DROID
```
@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
  runtime:
    image: simple_client
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/simple_client/Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      - SERVER_ARGS

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
187
policy/openpi-InternData-A1/examples/simple_client/main.py
Normal file
@@ -0,0 +1,187 @@
import dataclasses
import enum
import logging
import pathlib
import time

import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich.console
import rich.table
import tqdm
import tyro

logger = logging.getLogger(__name__)


class EnvMode(enum.Enum):
    """Supported environments."""

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"


@dataclasses.dataclass
class Args:
    """Command line arguments."""

    # Host and port to connect to the server.
    host: str = "0.0.0.0"
    # Port to connect to the server. If None, the server will use the default port.
    port: int | None = 8000
    # API key to use for the server.
    api_key: str | None = None
    # Number of steps to run the policy for.
    num_steps: int = 20
    # Path to save the timings to a parquet file. (e.g., timing.parquet)
    timing_file: pathlib.Path | None = None
    # Environment to run the policy in.
    env: EnvMode = EnvMode.ALOHA_SIM


class TimingRecorder:
    """Records timing measurements for different keys."""

    def __init__(self) -> None:
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Record a timing measurement for the given key."""
        if key not in self._timings:
            self._timings[key] = []
        self._timings[key].append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Get statistics for the given key."""
        times = self._timings[key]
        return {
            "mean": float(np.mean(times)),
            "std": float(np.std(times)),
            "p25": float(np.quantile(times, 0.25)),
            "p50": float(np.quantile(times, 0.50)),
            "p75": float(np.quantile(times, 0.75)),
            "p90": float(np.quantile(times, 0.90)),
            "p95": float(np.quantile(times, 0.95)),
            "p99": float(np.quantile(times, 0.99)),
        }

    def print_all_stats(self) -> None:
        """Print statistics for all keys in a concise format."""

        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )

        # Add metric column with custom styling
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)

        # Add statistical columns with consistent styling
        stat_columns = [
            ("Mean", "yellow", "mean"),
            ("Std", "yellow", "std"),
            ("P25", "magenta", "p25"),
            ("P50", "magenta", "p50"),
            ("P75", "magenta", "p75"),
            ("P90", "magenta", "p90"),
            ("P95", "magenta", "p95"),
            ("P99", "magenta", "p99"),
        ]

        for name, style, _ in stat_columns:
            table.add_column(name, justify="right", style=style, no_wrap=True)

        # Add rows for each metric with formatted values
        for key in sorted(self._timings.keys()):
            stats = self.get_stats(key)
            values = [f"{stats[key]:.1f}" for _, _, key in stat_columns]
            table.add_row(key, *values)

        # Print with custom console settings
        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Save the timings to a parquet file."""
        logger.info(f"Writing timings to {path}")
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)


def main(args: Args) -> None:
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info(f"Server metadata: {policy.get_server_metadata()}")

    # Send a few observations to make sure the model is loaded.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()

    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        inference_start = time.time()
        action = policy.infer(obs_fn())
        timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
        for key, value in action.get("server_timing", {}).items():
            timing_recorder.record(f"server_{key}", value)
        for key, value in action.get("policy_timing", {}).items():
            timing_recorder.record(f"policy_{key}", value)

    timing_recorder.print_all_stats()

    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)


def _random_observation_aloha() -> dict:
    return {
        "state": np.ones((14,)),
        "images": {
            "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
        },
        "prompt": "do something",
    }


def _random_observation_droid() -> dict:
    return {
        "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/joint_position": np.random.rand(7),
        "observation/gripper_position": np.random.rand(1),
        "prompt": "do something",
    }


def _random_observation_libero() -> dict:
    return {
        "observation/state": np.random.rand(8),
        "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "prompt": "do something",
    }


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main(tyro.cli(Args))
@@ -0,0 +1,5 @@
numpy>=1.22.4,<2.0.0
rich
tqdm
tyro
polars
@@ -0,0 +1,30 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
docstring-parser==0.16
    # via tyro
markdown-it-py==3.0.0
    # via rich
mdurl==0.1.2
    # via markdown-it-py
numpy==1.26.4
    # via -r examples/simple_client/requirements.in
polars==1.30.0
    # via -r examples/simple_client/requirements.in
pygments==2.19.1
    # via rich
rich==14.0.0
    # via
    #   -r examples/simple_client/requirements.in
    #   tyro
shtab==1.7.2
    # via tyro
tqdm==4.67.1
    # via -r examples/simple_client/requirements.in
typeguard==4.4.2
    # via tyro
typing-extensions==4.13.2
    # via
    #   typeguard
    #   tyro
tyro==0.9.22
    # via -r examples/simple_client/requirements.in
142
policy/openpi-InternData-A1/examples/ur5/README.md
Normal file
@@ -0,0 +1,142 @@
# UR5 Example

Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.

First, we define the `UR5Inputs` and `UR5Outputs` classes, which map UR5 environment observations into the model's input format and map the model's outputs back to environment actions. Check the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.

```python
@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    # The model's action dimension; used to pad the state vector.
    action_dim: int

    # Determines which model variant the transform targets.
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # First, concatenate the joints and gripper into the state vector,
        # then pad it to the model's action dimension.
        state = np.concatenate([data["joints"], data["gripper"]])
        state = transforms.pad_to_dim(state, self.action_dim)

        # Parse images to uint8 (H, W, C) if needed, since LeRobot automatically
        # stores them as float32 (C, H, W). This step is skipped during policy
        # inference, where images already arrive in the expected format.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist camera, replace it with zeros.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # The right wrist "slot" is unused, so its mask is set to False
                # for pi0; pi0-FAST expects all image masks to be True.
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```

Next, we define the `LeRobotUR5DataConfig` class, which specifies how raw UR5 data from a LeRobot dataset is processed for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).

```python
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Boilerplate for remapping keys from the LeRobot dataset to the names
        # that UR5Inputs expects.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
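The delta-action conversion above can be sketched in plain NumPy. This is a simplified stand-in for openpi's `DeltaActions`/`AbsoluteActions` transforms (the function names `to_delta`/`to_absolute` are illustrative), assuming an action chunk of shape `(horizon, 7)` and the current joint state broadcast over the horizon:

```python
import numpy as np

# Equivalent of make_bool_mask(6, -1): six True entries followed by one False
# entry, so the gripper dimension is left untouched.
mask = np.array([True] * 6 + [False])


def to_delta(actions: np.ndarray, state: np.ndarray) -> np.ndarray:
    """Convert absolute actions to deltas relative to the current state."""
    out = actions.copy()
    out[:, mask] -= state[mask]
    return out


def to_absolute(actions: np.ndarray, state: np.ndarray) -> np.ndarray:
    """Invert to_delta: add the current state back to the masked dimensions."""
    out = actions.copy()
    out[:, mask] += state[mask]
    return out
```

The two functions are exact inverses, which is why the same mask is pushed onto both the input (training) and output (inference) sides of the transform group.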

Finally, we define a `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.

```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # ``task`` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```
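For intuition on what the reloaded normalization stats do: state and action vectors are normalized per dimension before they reach the model, and the model's outputs are unnormalized on the way back. A minimal mean/std sketch (openpi's actual normalization transform also supports quantile-based stats; the function names here are illustrative):

```python
import numpy as np


def normalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Z-score each dimension using stored dataset statistics."""
    return (x - mean) / (std + eps)


def unnormalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Invert normalize, mapping model outputs back to the robot's units."""
    return x * (std + eps) + mean
```

Reusing the base checkpoint's `ur5e` stats (via `AssetsConfig`) means the fine-tuned model sees inputs scaled the same way as during pre-training, which is what helps transfer to a new setup.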