7 Commits

Author SHA1 Message Date
Karl Pertsch
b84cc75031 add binning jointpos 2025-04-25 05:28:23 +00:00
Karl Pertsch
c23bc86a0a load droid sim eval policies without credentials (#440)
small change to enable loading from the openpi sim eval bucket without credentials (for joint pos policies)
2025-04-17 15:39:53 -04:00
Arhan Jain
fe5d5580a4 load droid sim eval policies without credentials 2025-04-17 12:26:06 -07:00
Karl Pertsch
650b02e4ca add diffusion jointpos policy 2025-04-17 13:19:48 +00:00
Karl Pertsch
e43516e719 add diffusion droid policy 2025-04-14 20:15:23 +00:00
Karl Pertsch
20d63d47b7 additional policy 2025-04-14 19:18:09 +00:00
Karl Pertsch
1ce9ffe134 add DROID policies 2025-04-14 18:42:57 +00:00
139 changed files with 833 additions and 2175 deletions

0
.dockerignore Executable file → Normal file
View File

0
.github/CODEOWNERS vendored Executable file → Normal file
View File

0
.github/workflows/pre-commit.yml vendored Executable file → Normal file
View File

0
.github/workflows/test.yml vendored Executable file → Normal file
View File

2
.gitignore vendored Executable file → Normal file
View File

@@ -12,8 +12,6 @@ __pycache__/
# C extensions
*.so
third-party/*
# Distribution / packaging
.Python
build/

0
.gitmodules vendored Executable file → Normal file
View File

0
.pre-commit-config.yaml Executable file → Normal file
View File

0
.python-version Executable file → Normal file
View File

0
.vscode/settings.json vendored Executable file → Normal file
View File

0
CONTRIBUTING.md Executable file → Normal file
View File

0
LICENSE Executable file → Normal file
View File

0
README.md Executable file → Normal file
View File

0
docs/docker.md Executable file → Normal file
View File

0
docs/norm_stats.md Executable file → Normal file
View File

0
docs/remote_inference.md Executable file → Normal file
View File

0
examples/aloha_real/Dockerfile Executable file → Normal file
View File

0
examples/aloha_real/README.md Executable file → Normal file
View File

0
examples/aloha_real/compose.yml Executable file → Normal file
View File

0
examples/aloha_real/constants.py Executable file → Normal file
View File

0
examples/aloha_real/convert_aloha_data_to_lerobot.py Executable file → Normal file
View File

0
examples/aloha_real/env.py Executable file → Normal file
View File

0
examples/aloha_real/main.py Executable file → Normal file
View File

6
examples/aloha_real/real_env.py Executable file → Normal file
View File

@@ -49,11 +49,7 @@ class RealEnv:
init_node=init_node,
)
self.puppet_bot_right = InterbotixManipulatorXS(
robot_model="vx300s",
group_name="arm",
gripper_name="gripper",
robot_name="puppet_right",
init_node=False
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
)
if setup_robots:
self.setup_robots()

0
examples/aloha_real/requirements.in Executable file → Normal file
View File

0
examples/aloha_real/requirements.txt Executable file → Normal file
View File

0
examples/aloha_real/robot_utils.py Executable file → Normal file
View File

0
examples/aloha_real/video_display.py Executable file → Normal file
View File

View File

@@ -1,70 +0,0 @@
# Dockerfile for the Aloha real environment.
# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
SHELL ["/bin/bash", "-c"]
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
cmake \
curl \
libffi-dev \
python3-rosdep \
python3-rosinstall \
python3-rosinstall-generator \
whiptail \
git \
wget \
openssh-client \
ros-noetic-cv-bridge \
ros-noetic-usb-cam \
ros-noetic-realsense2-camera \
keyboard-configuration
WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
cd /python && \
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
tar -zxvf Python-3.10.14.tgz && \
cd Python-3.10.14 && \
ls -lhR && \
./configure --enable-optimizations && \
make install && \
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
cd ~ && rm -rf /python && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app
# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]

View File

@@ -1,126 +0,0 @@
# Run Aloha (Real Robot)
This example demonstrates how to run a policy on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
## Prerequisites
This repo uses a fork of the ALOHA repo, with very minor modifications to use RealSense cameras.
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
## With Docker
```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client
# Run the robot
python -m examples.aloha_real.main
```
Terminal window 2:
```bash
roslaunch aloha ros_nodes.launch
```
Terminal window 3:
```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```
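If you prefer to query the policy server directly from your own script rather than through the provided runtime, a minimal sketch using the `openpi-client` package is shown below. The host, port, prompt, and observation shapes are illustrative assumptions (they mirror the observations produced by `examples/aloha_real/env.py`); adapt them to your setup.
```python
# Minimal sketch: query a running policy server from Python.
import numpy as np
from openpi_client import websocket_client_policy

# Connect to the server started in terminal window 3 (host/port are assumptions).
policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

# Dummy ALOHA-style observation: 14-dim joint state plus four 224x224 camera images.
observation = {
    "state": np.zeros(14, dtype=np.float32),
    "images": {
        cam: np.zeros((3, 224, 224), dtype=np.uint8)
        for cam in ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
    },
    "prompt": "take the toast out of the toaster",
}

# The server returns a chunk of actions; each row is a 14-dim joint-position command.
actions = policy.infer(observation)["actions"]
print(actions.shape)
```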
## **ALOHA Checkpoint Guide**
The `pi0_base` model can be used zero-shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
While we've found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we've seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.
---
### **Toast Task**
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
- Works on both real toast and rubber fake toast
- Compatible with standard 2-slice toasters
- Works with plates of varying colors
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower-center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
### **Towel Task**
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
- Works on towels of varying solid colors
- Performance is worse on heavily textured or striped towels
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.
### **Tupperware Task**
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
- The policy has seen plates of varying solid colors.
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
- Tupperware should be on the left.
- Plate should be on the right or bottom.
- The tupperware flap should point toward the plate.
## Training on your own Aloha dataset
1. Convert the dataset to the LeRobot dataset v2.0 format.
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example, we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
2. Define a training config that uses the custom dataset.
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the `trossen` asset_id and a path to the pretrained checkpoint's asset directory within the `AssetsConfig` (see the sketch below).
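For reference, a minimal sketch of that normalization-stats override is shown below. The class and field names are assumed to match `src/openpi/training/config.py`, and the assets path is illustrative; consult the actual `pi0_aloha_pen_uncap` config for the authoritative version.
```python
# Sketch only: reuse the base checkpoint's "trossen" norm stats when fine-tuning on a
# custom ALOHA dataset. Field names are assumed from src/openpi/training/config.py.
from openpi.training import config as _config

assets = _config.AssetsConfig(
    # Asset directory of the pretrained checkpoint (path is illustrative).
    assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
    # Select the stats recorded for the trossen (ALOHA) robot configuration.
    asset_id="trossen",
)
```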

View File

@@ -1,66 +0,0 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
runtime:
image: aloha_real
depends_on:
- aloha_ros_nodes
- ros_master
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
aloha_ros_nodes:
image: aloha_real
depends_on:
- ros_master
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- /dev:/dev
command: roslaunch --wait aloha ros_nodes.launch
ros_master:
image: ros:noetic-robot
network_mode: host
privileged: true
command:
- roscore
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -1,71 +0,0 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
### Task parameters
### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
############################ Helper functions ############################
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
MASTER_POS2JOINT = (
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2

View File

@@ -1,272 +0,0 @@
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""
import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
use_videos: bool = True
tolerance_s: float = 0.0001
image_writer_processes: int = 10
image_writer_threads: int = 5
video_backend: str | None = None
DEFAULT_DATASET_CONFIG = DatasetConfig()
def create_empty_dataset(
repo_id: str,
robot_type: str,
mode: Literal["video", "image"] = "video",
*,
has_velocity: bool = False,
has_effort: bool = False,
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
motors = [
"right_waist",
"right_shoulder",
"right_elbow",
"right_forearm_roll",
"right_wrist_angle",
"right_wrist_rotate",
"right_gripper",
"left_waist",
"left_shoulder",
"left_elbow",
"left_forearm_roll",
"left_wrist_angle",
"left_wrist_rotate",
"left_gripper",
]
cameras = [
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
]
features = {
"observation.state": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
"action": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
}
if has_velocity:
features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
if has_effort:
features["observation.effort"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
for cam in cameras:
features[f"observation.images.{cam}"] = {
"dtype": mode,
"shape": (3, 480, 640),
"names": [
"channels",
"height",
"width",
],
}
if Path(LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
return LeRobotDataset.create(
repo_id=repo_id,
fps=50,
robot_type=robot_type,
features=features,
use_videos=dataset_config.use_videos,
tolerance_s=dataset_config.tolerance_s,
image_writer_processes=dataset_config.image_writer_processes,
image_writer_threads=dataset_config.image_writer_threads,
video_backend=dataset_config.video_backend,
)
def get_cameras(hdf5_files: list[Path]) -> list[str]:
with h5py.File(hdf5_files[0], "r") as ep:
# ignore depth channel, not currently handled
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
def has_velocity(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/qvel" in ep
def has_effort(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/effort" in ep
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
imgs_per_cam = {}
for camera in cameras:
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
if uncompressed:
# load all images in RAM
imgs_array = ep[f"/observations/images/{camera}"][:]
else:
import cv2
# load one compressed image after the other in RAM and uncompress
imgs_array = []
for data in ep[f"/observations/images/{camera}"]:
imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
imgs_array = np.array(imgs_array)
imgs_per_cam[camera] = imgs_array
return imgs_per_cam
def load_raw_episode_data(
ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
with h5py.File(ep_path, "r") as ep:
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
velocity = None
if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:])
effort = None
if "/observations/effort" in ep:
effort = torch.from_numpy(ep["/observations/effort"][:])
imgs_per_cam = load_raw_images_per_camera(
ep,
[
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
],
)
return imgs_per_cam, state, action, velocity, effort
def populate_dataset(
dataset: LeRobotDataset,
hdf5_files: list[Path],
task: str,
episodes: list[int] | None = None,
) -> LeRobotDataset:
if episodes is None:
episodes = range(len(hdf5_files))
for ep_idx in tqdm.tqdm(episodes):
ep_path = hdf5_files[ep_idx]
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
num_frames = state.shape[0]
for i in range(num_frames):
frame = {
"observation.state": state[i],
"action": action[i],
}
for camera, img_array in imgs_per_cam.items():
frame[f"observation.images.{camera}"] = img_array[i]
if velocity is not None:
frame["observation.velocity"] = velocity[i]
if effort is not None:
frame["observation.effort"] = effort[i]
dataset.add_frame(frame)
dataset.save_episode(task=task)
return dataset
def port_aloha(
raw_dir: Path,
repo_id: str,
raw_repo_id: str | None = None,
task: str = "DEBUG",
*,
episodes: list[int] | None = None,
push_to_hub: bool = True,
is_mobile: bool = False,
mode: Literal["video", "image"] = "image",
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if not raw_dir.exists():
if raw_repo_id is None:
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
download_raw(raw_dir, repo_id=raw_repo_id)
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
dataset = create_empty_dataset(
repo_id,
robot_type="mobile_aloha" if is_mobile else "aloha",
mode=mode,
has_effort=has_effort(hdf5_files),
has_velocity=has_velocity(hdf5_files),
dataset_config=dataset_config,
)
dataset = populate_dataset(
dataset,
hdf5_files,
task=task,
episodes=episodes,
)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
tyro.cli(port_aloha)

View File

@@ -1,57 +0,0 @@
from typing import List, Optional # noqa: UP035
import einops
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override
from examples.aloha_real import real_env as _real_env
class AlohaRealEnvironment(_environment.Environment):
"""An environment for an Aloha robot on real hardware."""
def __init__(
self,
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
render_height: int = 224,
render_width: int = 224,
) -> None:
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
self._render_height = render_height
self._render_width = render_width
self._ts = None
@override
def reset(self) -> None:
self._ts = self._env.reset()
@override
def is_episode_complete(self) -> bool:
return False
@override
def get_observation(self) -> dict:
if self._ts is None:
raise RuntimeError("Timestep is not set. Call reset() first.")
obs = self._ts.observation
for k in list(obs["images"].keys()):
if "_depth" in k:
del obs["images"][k]
for cam_name in obs["images"]:
img = image_tools.convert_to_uint8(
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
)
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
return {
"state": obs["qpos"],
"images": obs["images"],
}
@override
def apply_action(self, action: dict) -> None:
self._ts = self._env.step(action["actions"])

View File

@@ -1,51 +0,0 @@
import dataclasses
import logging
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro
from examples.aloha_real_lyt import env as _env
@dataclasses.dataclass
class Args:
host: str = "172.20.103.171"
port: int = 8090
action_horizon: int = 25
num_episodes: int = 1
max_episode_steps: int = 1000
def main(args: Args) -> None:
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
)
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
metadata = ws_client_policy.get_server_metadata()
runtime = _runtime.Runtime(
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=ws_client_policy,
action_horizon=args.action_horizon,
)
),
subscribers=[],
max_hz=50,
num_episodes=args.num_episodes,
max_episode_steps=args.max_episode_steps,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)

View File

@@ -1,171 +0,0 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time
from typing import Optional, List
import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
from examples.aloha_real import constants
from examples.aloha_real import robot_utils
# This is the reset position that is used by the standard Aloha runtime.
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
class RealEnv:
"""
Environment for real robot bi-manual manipulation
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
"""
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
# reset_position = START_ARM_POSE[:6]
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
self.puppet_bot_left = InterbotixManipulatorXS(
robot_model="vx300s",
group_name="arm",
gripper_name="gripper",
robot_name="puppet_left",
init_node=init_node,
)
self.puppet_bot_right = InterbotixManipulatorXS(
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
)
if setup_robots:
self.setup_robots()
self.recorder_left = robot_utils.Recorder("left", init_node=False)
self.recorder_right = robot_utils.Recorder("right", init_node=False)
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
self.gripper_command = JointSingleCommand(name="gripper")
def setup_robots(self):
robot_utils.setup_puppet_bot(self.puppet_bot_left)
robot_utils.setup_puppet_bot(self.puppet_bot_right)
def get_qpos(self):
left_qpos_raw = self.recorder_left.qpos
right_qpos_raw = self.recorder_right.qpos
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
] # this is position not joint
right_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
] # this is position not joint
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
def get_qvel(self):
left_qvel_raw = self.recorder_left.qvel
right_qvel_raw = self.recorder_right.qvel
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
def get_effort(self):
left_effort_raw = self.recorder_left.effort
right_effort_raw = self.recorder_right.effort
left_robot_effort = left_effort_raw[:7]
right_robot_effort = right_effort_raw[:7]
return np.concatenate([left_robot_effort, right_robot_effort])
def get_images(self):
return self.image_recorder.get_images()
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
self.gripper_command.cmd = left_gripper_desired_joint
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
right_gripper_desired_pos_normalized
)
self.gripper_command.cmd = right_gripper_desired_joint
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
def _reset_joints(self):
robot_utils.move_arms(
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
)
def _reset_gripper(self):
"""Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
)
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
)
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
return obs
def get_reward(self):
return 0
def reset(self, *, fake=False):
if not fake:
# Reboot puppet robot gripper motors
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
self._reset_joints()
self._reset_gripper()
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def step(self, action):
state_len = int(len(action) / 2)
left_action = action[:state_len]
right_action = action[state_len:]
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
self.set_gripper_pose(left_action[-1], right_action[-1])
time.sleep(constants.DT)
return dm_env.TimeStep(
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def get_action(master_bot_left, master_bot_right):
action = np.zeros(14) # (6 joints + 1 gripper) x 2 arms
# Arm actions
action[:6] = master_bot_left.dxl.joint_states.position[:6]
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
# Gripper actions
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
return action
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)

View File

@@ -1,18 +0,0 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets

View File

@@ -1,156 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
catkin-pkg==1.0.0
# via rospkg
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
contourpy==1.1.1
# via matplotlib
cycler==0.12.1
# via matplotlib
distro==1.9.0
# via rospkg
dm-control==1.0.23
# via -r examples/aloha_real/requirements.in
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
docutils==0.20.1
# via catkin-pkg
einops==0.8.0
# via -r examples/aloha_real/requirements.in
etils==1.3.0
# via mujoco
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
h5py==3.11.0
# via -r examples/aloha_real/requirements.in
idna==3.10
# via requests
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.7.5
# via -r examples/aloha_real/requirements.in
mdurl==0.1.2
# via markdown-it-py
modern-robotics==1.1.1
# via -r examples/aloha_real/requirements.in
msgpack==1.1.0
# via -r examples/aloha_real/requirements.in
mujoco==3.2.3
# via dm-control
numpy==1.24.4
# via
# -r examples/aloha_real/requirements.in
# contourpy
# dm-control
# dm-env
# h5py
# labmaze
# matplotlib
# modern-robotics
# mujoco
# opencv-python
# pyquaternion
# scipy
opencv-python==4.10.0.84
# via -r examples/aloha_real/requirements.in
packaging==24.2
# via
# -r examples/aloha_real/requirements.in
# matplotlib
pexpect==4.9.0
# via -r examples/aloha_real/requirements.in
pillow==10.4.0
# via
# -r examples/aloha_real/requirements.in
# matplotlib
protobuf==5.29.1
# via dm-control
ptyprocess==0.7.0
# via pexpect
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.1.4
# via
# catkin-pkg
# dm-control
# matplotlib
pyquaternion==0.9.9
# via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
# via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
# via
# catkin-pkg
# matplotlib
pyyaml==6.0.2
# via
# -r examples/aloha_real/requirements.in
# rospkg
requests==2.32.3
# via
# -r examples/aloha_real/requirements.in
# dm-control
rich==13.9.4
# via tyro
rospkg==1.5.1
# via -r examples/aloha_real/requirements.in
scipy==1.10.1
# via dm-control
setuptools==75.3.0
# via
# catkin-pkg
# dm-control
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_real/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_real/requirements.in
zipp==3.20.2
# via etils

View File

@@ -1,275 +0,0 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time
from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState
from examples.aloha_real import constants
class ImageRecorder:
def __init__(self, init_node=True, is_debug=False):
self.is_debug = is_debug
self.bridge = CvBridge()
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
if init_node:
rospy.init_node("image_recorder", anonymous=True)
for cam_name in self.camera_names:
setattr(self, f"{cam_name}_rgb_image", None)
setattr(self, f"{cam_name}_depth_image", None)
setattr(self, f"{cam_name}_timestamp", 0.0)
if cam_name == "cam_high":
callback_func = self.image_cb_cam_high
elif cam_name == "cam_low":
callback_func = self.image_cb_cam_low
elif cam_name == "cam_left_wrist":
callback_func = self.image_cb_cam_left_wrist
elif cam_name == "cam_right_wrist":
callback_func = self.image_cb_cam_right_wrist
else:
raise NotImplementedError
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
if self.is_debug:
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
time.sleep(0.5)
def image_cb(self, cam_name, data):
setattr(
self,
f"{cam_name}_rgb_image",
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
)
# setattr(
# self,
# f"{cam_name}_depth_image",
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
# )
setattr(
self,
f"{cam_name}_timestamp",
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
)
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
if self.is_debug:
getattr(self, f"{cam_name}_timestamps").append(
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
)
def image_cb_cam_high(self, data):
cam_name = "cam_high"
return self.image_cb(cam_name, data)
def image_cb_cam_low(self, data):
cam_name = "cam_low"
return self.image_cb(cam_name, data)
def image_cb_cam_left_wrist(self, data):
cam_name = "cam_left_wrist"
return self.image_cb(cam_name, data)
def image_cb_cam_right_wrist(self, data):
cam_name = "cam_right_wrist"
return self.image_cb(cam_name, data)
def get_images(self):
image_dict = {}
for cam_name in self.camera_names:
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
time.sleep(0.00001)
rgb_image = getattr(self, f"{cam_name}_rgb_image")
depth_image = getattr(self, f"{cam_name}_depth_image")
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
image_dict[cam_name] = rgb_image
image_dict[f"{cam_name}_depth"] = depth_image
return image_dict
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
for cam_name in self.camera_names:
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
print(f"{cam_name} {image_freq=:.2f}")
print()
class Recorder:
def __init__(self, side, init_node=True, is_debug=False):
self.secs = None
self.nsecs = None
self.qpos = None
self.effort = None
self.arm_command = None
self.gripper_command = None
self.is_debug = is_debug
if init_node:
rospy.init_node("recorder", anonymous=True)
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_group",
JointGroupCommand,
self.puppet_arm_commands_cb,
)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_single",
JointSingleCommand,
self.puppet_gripper_commands_cb,
)
if self.is_debug:
self.joint_timestamps = deque(maxlen=50)
self.arm_command_timestamps = deque(maxlen=50)
self.gripper_command_timestamps = deque(maxlen=50)
time.sleep(0.1)
def puppet_state_cb(self, data):
self.qpos = data.position
self.qvel = data.velocity
self.effort = data.effort
self.data = data
if self.is_debug:
self.joint_timestamps.append(time.time())
def puppet_arm_commands_cb(self, data):
self.arm_command = data.cmd
if self.is_debug:
self.arm_command_timestamps.append(time.time())
def puppet_gripper_commands_cb(self, data):
self.gripper_command = data.cmd
if self.is_debug:
self.gripper_command_timestamps.append(time.time())
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
joint_freq = 1 / dt_helper(self.joint_timestamps)
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
def get_arm_joint_positions(bot):
return bot.arm.core.joint_states.position[:6]
def get_arm_gripper_positions(bot):
return bot.gripper.core.joint_states.position[6]
def move_arms(bot_list, target_pose_list, move_time=1):
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
for t in range(num_steps):
for bot_id, bot in enumerate(bot_list):
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
time.sleep(constants.DT)
def move_grippers(bot_list, target_pose_list, move_time):
print(f"Moving grippers to {target_pose_list=}")
gripper_command = JointSingleCommand(name="gripper")
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
for t in range(num_steps):
d = {}
for bot_id, bot in enumerate(bot_list):
gripper_command.cmd = traj_list[bot_id][t]
bot.gripper.core.pub_single.publish(gripper_command)
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
f.write(json.dumps(d) + "\n")
time.sleep(constants.DT)
def setup_puppet_bot(bot):
bot.dxl.robot_reboot_motors("single", "gripper", True)
bot.dxl.robot_set_operating_modes("group", "arm", "position")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_on(bot)
def setup_master_bot(bot):
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_off(bot)
def set_standard_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def set_low_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def torque_off(bot):
bot.dxl.robot_torque_enable("group", "arm", False)
bot.dxl.robot_torque_enable("single", "gripper", False)
def torque_on(bot):
bot.dxl.robot_torque_enable("group", "arm", True)
bot.dxl.robot_torque_enable("single", "gripper", True)
# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
print("\nSyncing!")
# activate master arms
torque_on(master_bot_left)
torque_on(master_bot_right)
# get puppet arm positions
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
# get puppet gripper positions
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
# move master arms to puppet positions
move_arms(
[master_bot_left, master_bot_right],
[puppet_left_qpos, puppet_right_qpos],
move_time=1,
)
# move master grippers to puppet positions
move_grippers(
[master_bot_left, master_bot_right],
[puppet_left_gripper, puppet_right_gripper],
move_time=1,
)

View File

@@ -1,36 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoDisplay(_subscriber.Subscriber):
"""Displays video frames."""
def __init__(self) -> None:
self._ax: plt.Axes | None = None
self._plt_img: plt.Image | None = None
@override
def on_episode_start(self) -> None:
plt.ion()
self._ax = plt.subplot()
self._plt_img = None
@override
def on_step(self, observation: dict, action: dict) -> None:
assert self._ax is not None
im = observation["image"][0] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
if self._plt_img is None:
self._plt_img = self._ax.imshow(im)
else:
self._plt_img.set_data(im)
plt.pause(0.001)
@override
def on_episode_end(self) -> None:
plt.ioff()
plt.close()

0
examples/aloha_sim/Dockerfile Executable file → Normal file
View File

0
examples/aloha_sim/README.md Executable file → Normal file
View File

0
examples/aloha_sim/compose.yml Executable file → Normal file
View File

0
examples/aloha_sim/env.py Executable file → Normal file
View File

0
examples/aloha_sim/main.py Executable file → Normal file
View File

0
examples/aloha_sim/requirements.in Executable file → Normal file
View File

0
examples/aloha_sim/requirements.txt Executable file → Normal file
View File

0
examples/aloha_sim/saver.py Executable file → Normal file
View File

0
examples/droid/README.md Executable file → Normal file
View File

0
examples/droid/main.py Executable file → Normal file
View File

56
examples/inference.ipynb Executable file → Normal file
View File

@@ -6,8 +6,6 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['HF_ENDPOINT'] = \"https://hf-mirror.com\"\n",
"import dataclasses\n",
"\n",
"import jax\n",
@@ -20,13 +18,6 @@
"from openpi.training import data_loader as _data_loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
@@ -40,53 +31,10 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fa8d45bf6fe5420f8b152ff52794ee45",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0.00/11.2G [00:00<?, ?iB/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"Some kwargs in processor config are unused and will not have any effect: action_dim, scale, time_horizon, vocab_size, min_token. \n",
"Some kwargs in processor config are unused and will not have any effect: action_dim, scale, time_horizon, vocab_size, min_token. \n"
]
},
{
"ename": "ValueError",
"evalue": "quantile stats must be provided if use_quantile_norm is True. Key actions is missing q01 or q99.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m checkpoint_dir \u001b[38;5;241m=\u001b[39m download\u001b[38;5;241m.\u001b[39mmaybe_download(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ms3://openpi-assets/checkpoints/pi0_base\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Create a trained policy.\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m policy \u001b[38;5;241m=\u001b[39m \u001b[43m_policy_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_trained_policy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m example \u001b[38;5;241m=\u001b[39m droid_policy\u001b[38;5;241m.\u001b[39mmake_droid_example()\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/policies/policy_config.py:72\u001b[0m, in \u001b[0;36mcreate_trained_policy\u001b[0;34m(train_config, checkpoint_dir, repack_transforms, sample_kwargs, default_prompt, norm_stats)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAsset id is required to load norm stats.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 64\u001b[0m norm_stats \u001b[38;5;241m=\u001b[39m _checkpoints\u001b[38;5;241m.\u001b[39mload_norm_stats(checkpoint_dir \u001b[38;5;241m/\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massets\u001b[39m\u001b[38;5;124m\"\u001b[39m, data_config\u001b[38;5;241m.\u001b[39masset_id)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _policy\u001b[38;5;241m.\u001b[39mPolicy(\n\u001b[1;32m 67\u001b[0m model,\n\u001b[1;32m 68\u001b[0m transforms\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 69\u001b[0m \u001b[38;5;241m*\u001b[39mrepack_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[1;32m 70\u001b[0m transforms\u001b[38;5;241m.\u001b[39mInjectDefaultPrompt(default_prompt),\n\u001b[1;32m 71\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mdata_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[0;32m---> 72\u001b[0m \u001b[43mtransforms\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mNormalize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnorm_stats\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_quantiles\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_quantile_norm\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 73\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mmodel_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[1;32m 74\u001b[0m ],\n\u001b[1;32m 75\u001b[0m output_transforms\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 76\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mmodel_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 77\u001b[0m transforms\u001b[38;5;241m.\u001b[39mUnnormalize(norm_stats, use_quantiles\u001b[38;5;241m=\u001b[39mdata_config\u001b[38;5;241m.\u001b[39muse_quantile_norm),\n\u001b[1;32m 78\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mdata_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 79\u001b[0m \u001b[38;5;241m*\u001b[39mrepack_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 80\u001b[0m ],\n\u001b[1;32m 81\u001b[0m sample_kwargs\u001b[38;5;241m=\u001b[39msample_kwargs,\n\u001b[1;32m 82\u001b[0m metadata\u001b[38;5;241m=\u001b[39mtrain_config\u001b[38;5;241m.\u001b[39mpolicy_metadata,\n\u001b[1;32m 83\u001b[0m )\n",
"File \u001b[0;32m<string>:6\u001b[0m, in \u001b[0;36m__init__\u001b[0;34m(self, norm_stats, use_quantiles, strict)\u001b[0m\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/transforms.py:124\u001b[0m, in \u001b[0;36mNormalize.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__post_init__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm_stats \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_quantiles:\n\u001b[0;32m--> 124\u001b[0m \u001b[43m_assert_quantile_stats\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_stats\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/transforms.py:431\u001b[0m, in \u001b[0;36m_assert_quantile_stats\u001b[0;34m(norm_stats)\u001b[0m\n\u001b[1;32m 429\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m flatten_dict(norm_stats)\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m v\u001b[38;5;241m.\u001b[39mq01 \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m v\u001b[38;5;241m.\u001b[39mq99 \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 431\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 432\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantile stats must be provided if use_quantile_norm is True. Key \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is missing q01 or q99.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 433\u001b[0m )\n",
"\u001b[0;31mValueError\u001b[0m: quantile stats must be provided if use_quantile_norm is True. Key actions is missing q01 or q99."
]
}
],
"outputs": [],
"source": [
"\n",
"config = _config.get_config(\"pi0_fast_droid\")\n",
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
"# checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_base\")\n",
"\n",
"# Create a trained policy.\n",
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
@@ -181,7 +129,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.12"
"version": "3.11.9"
}
},
"nbformat": 4,

0
examples/libero/Dockerfile Executable file → Normal file
View File

0
examples/libero/README.md Executable file → Normal file
View File

0
examples/libero/compose.yml Executable file → Normal file
View File

0
examples/libero/convert_libero_data_to_lerobot.py Executable file → Normal file
View File

0
examples/libero/main.py Executable file → Normal file
View File

0
examples/libero/requirements.in Executable file → Normal file
View File

0
examples/libero/requirements.txt Executable file → Normal file
View File

View File

@@ -1,32 +0,0 @@
# Dockerfile for the simple client.
# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"

View File

@@ -1,30 +0,0 @@
# Simple Client
A minimal client that sends observations to the server and prints the inference rate.
You can specify which runtime environment to use using the `--env` flag. You can see the available options by running:
```bash
uv run examples/simple_client/main.py --help
```
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/simple_client/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
uv run examples/simple_client/main.py --env DROID
```
Terminal window 2:
```bash
uv run scripts/serve_policy.py --env DROID
```

View File

@@ -1,353 +0,0 @@
import numpy as np
from einops import rearrange
from collections import deque
import rospy
from std_msgs.msg import Header
from geometry_msgs.msg import Twist
from sensor_msgs.msg import JointState, Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
import threading
class RosOperator:
def __init__(self, args):
self.robot_base_deque = None
self.puppet_arm_right_deque = None
self.puppet_arm_left_deque = None
self.img_front_deque = None
self.img_right_deque = None
self.img_left_deque = None
self.img_front_depth_deque = None
self.img_right_depth_deque = None
self.img_left_depth_deque = None
self.bridge = None
self.puppet_arm_left_publisher = None
self.puppet_arm_right_publisher = None
self.robot_base_publisher = None
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_lock = None
self.args = args
self.ctrl_state = False
self.ctrl_state_lock = threading.Lock()
self.init()
self.init_ros()
def init(self):
self.bridge = CvBridge()
self.img_left_deque = deque()
self.img_right_deque = deque()
self.img_front_deque = deque()
self.img_left_depth_deque = deque()
self.img_right_depth_deque = deque()
self.img_front_depth_deque = deque()
self.puppet_arm_left_deque = deque()
self.puppet_arm_right_deque = deque()
self.robot_base_deque = deque()
self.puppet_arm_publish_lock = threading.Lock()
self.puppet_arm_publish_lock.acquire()
def puppet_arm_publish(self, left, right):
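# Publish one JointState command to each puppet arm with the given joint positions.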
# Default velocity and effort values
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
0.1527481079101562, -0.013187408447265625, -0.013187408447265625,
0.0, -0.03076934814453125, -0.3296699523925781, 0.43956756591797,
0.5208797454833984, -0.11868095397949219, 0.03956031799316406, 0.0]
# Correct the gripper position
left[-1] *= 12
right[-1] *= 12
# Always keep it non-negative; clip values below 0 to 0
left[-1] = max(left[-1], 0)
right[-1] = max(right[-1], 0)
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # Set the joint names
joint_state_msg.position = left
joint_state_msg.velocity = last_velocity[:7]
joint_state_msg.effort = last_effort[:7]
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right
joint_state_msg.velocity = last_velocity[7:]
joint_state_msg.effort = last_effort[7:]
self.puppet_arm_right_publisher.publish(joint_state_msg)
def robot_base_publish(self, vel):
vel_msg = Twist()
vel_msg.linear.x = vel[0]
vel_msg.linear.y = 0
vel_msg.linear.z = 0
vel_msg.angular.x = 0
vel_msg.angular.y = 0
vel_msg.angular.z = vel[1]
self.robot_base_publisher.publish(vel_msg)
def puppet_arm_publish_continuous(self, left, right):
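# Incrementally step both arms toward the target joint positions at publish_rate, moving each joint by at most arm_steps_length per cycle, until the targets are reached or the publish lock is released.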
rate = rospy.Rate(self.args.publish_rate)
left_arm = None
right_arm = None
while not rospy.is_shutdown():
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
break
left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
flag = True
step = 0
while flag and not rospy.is_shutdown():
if self.puppet_arm_publish_lock.acquire(False):
return
left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
flag = False
for i in range(len(left)):
if left_diff[i] < self.args.arm_steps_length[i]:
left_arm[i] = left[i]
else:
left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
flag = True
for i in range(len(right)):
if right_diff[i] < self.args.arm_steps_length[i]:
right_arm[i] = right[i]
else:
right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
flag = True
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # Set the joint names
joint_state_msg.position = left_arm
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right_arm
self.puppet_arm_right_publisher.publish(joint_state_msg)
step += 1
print("puppet_arm_publish_continuous:", step)
rate.sleep()
def puppet_arm_publish_linear(self, left, right):
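# Linearly interpolate from the current arm positions to the targets over num_step waypoints at 200 Hz, keeping the gripper joint at its final commanded value throughout.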
num_step = 100
rate = rospy.Rate(200)
left_arm = None
right_arm = None
while not rospy.is_shutdown():
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
break
traj_left_list = np.linspace(left_arm, left, num_step)
traj_right_list = np.linspace(right_arm, right, num_step)
for i in range(len(traj_left_list)):
traj_left = traj_left_list[i]
traj_right = traj_right_list[i]
traj_left[-1] = left[-1]
traj_right[-1] = right[-1]
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # Set the joint names
joint_state_msg.position = traj_left
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = traj_right
self.puppet_arm_right_publisher.publish(joint_state_msg)
rate.sleep()
def puppet_arm_publish_continuous_thread(self, left, right):
if self.puppet_arm_publish_thread is not None:
self.puppet_arm_publish_lock.release()
self.puppet_arm_publish_thread.join()
self.puppet_arm_publish_lock.acquire(False)
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
self.puppet_arm_publish_thread.start()
def get_frame(self):
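# Return the most recent synchronized frame: take the earliest of the latest timestamps across all streams, drop stale messages, and return False if any stream cannot cover that time.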
if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
return False
if self.args.use_depth_image:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
else:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
return False
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_deque.popleft()
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_deque.popleft()
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_deque.popleft()
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_left_deque.popleft()
puppet_arm_left = self.puppet_arm_left_deque.popleft()
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_right_deque.popleft()
puppet_arm_right = self.puppet_arm_right_deque.popleft()
img_left_depth = None
if self.args.use_depth_image:
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_depth_deque.popleft()
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
img_right_depth = None
if self.args.use_depth_image:
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_depth_deque.popleft()
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
img_front_depth = None
if self.args.use_depth_image:
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_depth_deque.popleft()
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
robot_base = None
if self.args.use_robot_base:
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
self.robot_base_deque.popleft()
robot_base = self.robot_base_deque.popleft()
return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base)
def img_left_callback(self, msg):
if len(self.img_left_deque) >= 2000:
self.img_left_deque.popleft()
self.img_left_deque.append(msg)
def img_right_callback(self, msg):
if len(self.img_right_deque) >= 2000:
self.img_right_deque.popleft()
self.img_right_deque.append(msg)
def img_front_callback(self, msg):
if len(self.img_front_deque) >= 2000:
self.img_front_deque.popleft()
self.img_front_deque.append(msg)
def img_left_depth_callback(self, msg):
if len(self.img_left_depth_deque) >= 2000:
self.img_left_depth_deque.popleft()
self.img_left_depth_deque.append(msg)
def img_right_depth_callback(self, msg):
if len(self.img_right_depth_deque) >= 2000:
self.img_right_depth_deque.popleft()
self.img_right_depth_deque.append(msg)
def img_front_depth_callback(self, msg):
if len(self.img_front_depth_deque) >= 2000:
self.img_front_depth_deque.popleft()
self.img_front_depth_deque.append(msg)
def puppet_arm_left_callback(self, msg):
if len(self.puppet_arm_left_deque) >= 2000:
self.puppet_arm_left_deque.popleft()
self.puppet_arm_left_deque.append(msg)
def puppet_arm_right_callback(self, msg):
if len(self.puppet_arm_right_deque) >= 2000:
self.puppet_arm_right_deque.popleft()
self.puppet_arm_right_deque.append(msg)
def robot_base_callback(self, msg):
if len(self.robot_base_deque) >= 2000:
self.robot_base_deque.popleft()
self.robot_base_deque.append(msg)
def ctrl_callback(self, msg):
self.ctrl_state_lock.acquire()
self.ctrl_state = msg.data
self.ctrl_state_lock.release()
def get_ctrl_state(self):
self.ctrl_state_lock.acquire()
state = self.ctrl_state
self.ctrl_state_lock.release()
return state
def init_ros(self):
rospy.init_node('joint_state_publisher', anonymous=True)
rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
if self.args.use_depth_image:
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
# rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
# self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
# self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
# self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.img_left_topic = '/camera_l/color/image_raw'
args.img_right_topic = '/camera_r/color/image_raw'
args.img_front_topic = '/camera_f/color/image_raw'
args.puppet_arm_left_cmd_topic = '/master/joint_left'
args.puppet_arm_right_cmd_topic = '/master/joint_right'
args.puppet_arm_left_topic = '/puppet/joint_left'
args.puppet_arm_right_topic = '/puppet/joint_right'
args.publish_rate = 30
args.use_robot_base = False
args.use_actions_interpolation = False
args.use_depth_image = False
a = RosOperator(args)
print(a)

View File

@@ -1,42 +0,0 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
runtime:
image: simple_client
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/simple_client/Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -1,206 +0,0 @@
import dataclasses
import enum
import logging
import time
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import tyro
import rospy
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from agilex_utils import RosOperator
class EnvMode(enum.Enum):
"""Supported environments."""
ALOHA = "aloha"
ALOHA_SIM = "aloha_sim"
DROID = "droid"
LIBERO = "libero"
AGILEX_ALOHA = "agilex_arx_3camera_aloha"
@dataclasses.dataclass
class Args:
host: str = "172.20.103.171"
port: int = 8090
env: EnvMode = EnvMode.AGILEX_ALOHA
num_steps: int = 10
def main(args: Args) -> None:
obs_fn = {
EnvMode.ALOHA: _random_observation_aloha,
EnvMode.ALOHA_SIM: _random_observation_aloha,
EnvMode.DROID: _random_observation_droid,
EnvMode.LIBERO: _random_observation_libero,
EnvMode.AGILEX_ALOHA: observation_agilex_3camera_aloha,
}[args.env]
policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
)
logging.info(f"Server metadata: {policy.get_server_metadata()}")
args_ros, ros_operator = init_agilex_3camera_aloha()
# Send 1 observation to make sure the model is loaded.
policy.infer(obs_fn(args_ros, ros_operator))
# test inference
start = time.time()
for _ in range(args.num_steps):
policy.infer(obs_fn(args_ros, ros_operator))
end = time.time()
print(f"Total time taken: {end - start:.2f} s")
print(f"Average inference time: {1000 * (end - start) / args.num_steps:.2f} ms")
if 1000 * (end - start) / args.num_steps < 500:
logging.info("Inference time is less than 0.5 second! Its good!")
else:
logging.warning("Inference time is more than 0.5 second! Its bad!")
# pub
master_arm_left_publisher = rospy.Publisher(args_ros.master_arm_left_topic, JointState, queue_size=10)
master_arm_right_publisher = rospy.Publisher(args_ros.master_arm_right_topic, JointState, queue_size=10)
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # Set the joint names
rate = rospy.Rate(30)
# Default velocity and effort values
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
0.0, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.03296661376953125]
while True:
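# Query the policy for a chunk of actions, then publish each 14-dim action (7 joints per arm) to the master arm topics at 30 Hz.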
actions = policy.infer(obs_fn(args_ros, ros_operator))['actions']
for idx, action in enumerate(actions):
if(rospy.is_shutdown()):
break
# print(action)
print(idx, np.round(action[:7], 4))
cur_timestamp = rospy.Time.now() # Set the timestamp
joint_state_msg.header.stamp = cur_timestamp
joint_state_msg.position = action[:7]
joint_state_msg.velocity = last_velocity[:7]
joint_state_msg.effort = last_effort[:7]
# import pdb
# pdb.set_trace()
master_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = action[7:]
joint_state_msg.velocity = last_velocity[7:]
joint_state_msg.effort = last_effort[7:]
master_arm_right_publisher.publish(joint_state_msg)
if(rospy.is_shutdown()):
break
rate.sleep()
def init_agilex_3camera_aloha():
import argparse
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.img_left_topic = '/camera_l/color/image_raw'
args.img_right_topic = '/camera_r/color/image_raw'
args.img_front_topic = '/camera_f/color/image_raw'
args.master_arm_left_topic = '/master/joint_left'
args.master_arm_right_topic = '/master/joint_right'
args.puppet_arm_left_topic = '/puppet/joint_left'
args.puppet_arm_right_topic = '/puppet/joint_right'
args.publish_rate = 30
args.use_robot_base = False
args.use_actions_interpolation = False
args.use_depth_image = False
ros_operator = RosOperator(args)
return args, ros_operator
def observation_agilex_3camera_aloha(args, ros_operator: RosOperator):
print_flag = True
rate = rospy.Rate(args.publish_rate)
while not rospy.is_shutdown():
result = ros_operator.get_frame()
if not result:
if print_flag:
print("syn fail")
print_flag = False
rate.sleep()
continue
print_flag = True
(img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base) = result
break
state = np.concatenate([
puppet_arm_left.position, puppet_arm_right.position
])
# a = np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)
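# Convert the HWC camera frames to CHW to match the (3, 224, 224) image layout used by the aloha observations.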
img_front = np.transpose(img_front, (2, 0, 1))
img_left = np.transpose(img_left, (2, 0, 1))
img_right = np.transpose(img_right, (2, 0, 1))
return {
"state": state,
"images": {
"cam_high": img_front,
"cam_left_wrist": img_left,
"cam_right_wrist": img_right,
},
"prompt": "weigh a reagent by a balance",
}
def _random_observation_aloha() -> dict:
return {
"state": np.ones((14,)),
"images": {
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
},
"prompt": "do something",
}
def _random_observation_droid() -> dict:
return {
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/joint_position": np.random.rand(7),
"observation/gripper_position": np.random.rand(1),
"prompt": "do something",
}
def _random_observation_libero() -> dict:
return {
"observation/state": np.random.rand(8),
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"prompt": "do something",
}
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main(tyro.cli(Args))
# args, ros_operator = init_agilex_3camera_aloha()
# observation_agilex_3camera_aloha(args, ros_operator)
# print()

View File

@@ -1,2 +0,0 @@
numpy
tyro

View File

@@ -1,27 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
backports-cached-property==1.0.2
# via tyro
docstring-parser==0.16
# via tyro
eval-type-backport==0.1.3
# via tyro
markdown-it-py==2.2.0
# via rich
mdurl==0.1.2
# via markdown-it-py
numpy==1.21.6
# via -r examples/simple_client/requirements.in
pygments==2.17.2
# via rich
rich==13.8.1
# via tyro
shtab==1.7.1
# via tyro
typing-extensions==4.7.1
# via
# markdown-it-py
# rich
# tyro
tyro==0.9.1
# via -r examples/simple_client/requirements.in

0
examples/policy_records.ipynb Executable file → Normal file
View File

0
examples/simple_client/Dockerfile Executable file → Normal file
View File

0
examples/simple_client/README.md Executable file → Normal file
View File

0
examples/simple_client/compose.yml Executable file → Normal file
View File

0
examples/simple_client/main.py Executable file → Normal file
View File

0
examples/simple_client/requirements.in Executable file → Normal file
View File

0
examples/simple_client/requirements.txt Executable file → Normal file
View File

0
examples/ur5/README.md Executable file → Normal file
View File

0
packages/openpi-client/pyproject.toml Executable file → Normal file
View File

0
packages/openpi-client/src/openpi_client/__init__.py Executable file → Normal file
View File

5
pyproject.toml Executable file → Normal file
View File

@@ -3,7 +3,7 @@ name = "openpi"
version = "0.1.0"
description = "Physical Intelligence open source repo"
readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.11"
license = { file = "LICENSE" }
dependencies = [
"augmax>=0.3.4",
@@ -21,6 +21,7 @@ dependencies = [
"ml_collections==1.0.0",
"numpy>=1.26.4",
"numpydantic>=1.6.6",
+"opencv-python>=4.10.0.84",
"openpi-client",
"orbax-checkpoint==0.11.1",
"pillow>=11.0.0",
@@ -64,7 +65,7 @@ members = ["packages/*"]
[tool.ruff]
line-length = 120
-target-version = "py310"
+target-version = "py311"
extend-exclude = ["docker", "third_party"]
[tool.ruff.lint]

0
scripts/__init__.py Executable file → Normal file
View File

0
scripts/compute_norm_stats.py Executable file → Normal file
View File

0
scripts/docker/compose.yml Executable file → Normal file
View File

0
scripts/docker/serve_policy.Dockerfile Executable file → Normal file
View File

0
scripts/serve_policy.py Executable file → Normal file
View File

0
scripts/train.py Executable file → Normal file
View File

0
scripts/train_test.py Executable file → Normal file
View File

0
src/openpi/__init__.py Executable file → Normal file
View File

0
src/openpi/conftest.py Executable file → Normal file
View File

0
src/openpi/models/__init__.py Executable file → Normal file
View File

View File

@@ -0,0 +1,466 @@
import math
from typing import Literal
import chex
from einops import einops
from flax import linen as nn
from flax.linen.module import Module
from flax.linen.module import compact
from flax.struct import dataclass
from flax.typing import Array
import jax
import jax.numpy as jnp
class FsqCodebook(nn.Module):
input_dim: int
target_codebook_size: int
codebook_type: Literal["fsq", "lfq", "custom"]
_bins_per_dim: tuple[int, ...] | None = None
@property
def bins_per_dim(self):
if self._bins_per_dim is not None:
return self._bins_per_dim
if self.codebook_type == "fsq":
return self._get_bins_fsq(self.target_codebook_size)
elif self.codebook_type == "lfq": # noqa: RET505
return self._get_bins_lfq(self.target_codebook_size)
elif self.codebook_type == "custom":
return self._get_bins_custom(self.target_codebook_size)
else:
raise ValueError(f"Codebook type {self.codebook_type} not supported.")
@property
def place_values(self):
place_values = [1]
for b in self.bins_per_dim[:-1]:
place_values.append(place_values[-1] * b)
return jnp.array(place_values)
@staticmethod
def _get_bins_fsq(target_codebook_size):
"""
Get bins per dimension based on codebook size, from the original FSQ paper.
"""
if target_codebook_size == 2**8:
return (8, 6, 5)
elif target_codebook_size == 2**10: # noqa: RET505
return (8, 5, 5, 5)
elif target_codebook_size == 2**12:
return (7, 5, 5, 5, 5)
elif target_codebook_size == 2**14:
return (8, 8, 8, 6, 5)
elif target_codebook_size == 2**16:
return (8, 8, 8, 5, 5, 5)
else:
raise ValueError(f"Codebook size {target_codebook_size} not supported.")
@staticmethod
def _get_bins_custom(target_codebook_size):
if target_codebook_size == 2**8:
return (16, 16)
elif target_codebook_size == 2**10: # noqa: RET505
return (32, 32)
elif target_codebook_size == 2**12:
return (64, 64)
elif target_codebook_size == 2**14:
return (128, 128)
elif target_codebook_size == 2**16:
return (256, 256)
return None
@staticmethod
def _get_bins_lfq(target_codebook_size):
"""
Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
"""
assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"
return (2,) * int(math.log2(target_codebook_size))
def setup(self):
self.proj_down = nn.Dense(len(self.bins_per_dim))
self.proj_up = nn.Dense(self.input_dim)
def __call__(self, inputs):
tokens, z = self.encode(inputs)
output = self.decode(tokens, z_grad=z)
return tokens, output
def encode(self, inputs):
bases = jnp.array(self.bins_per_dim)
x = self.proj_down(inputs)
z = jnp.tanh(x)
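# tanh bounds each latent dimension to [-1, 1] so it can be rounded onto a uniform grid of bins.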
# Quantize
digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
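# Combine the per-dimension digits into a single integer token (mixed-radix code over bins_per_dim).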
tokens = self.undigitize(digits)
return tokens, z
def decode(self, tokens, z_grad: jax.Array | None = None):
bases = jnp.array(self.bins_per_dim)
digits = self.digitize(tokens)
z_q = digits / (bases - 1) * 2 - 1
if z_grad is not None:
chex.assert_equal_shape([z_q, z_grad])
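# Straight-through estimator: the forward pass keeps the quantized z_q while gradients flow through z_grad.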
z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
return self.proj_up(z_q)
def undigitize(self, digits):
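# Mixed-radix digits -> single integer token.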
return jnp.sum(digits * jnp.array(self.place_values), axis=-1)
def digitize(self, tokens):
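# Integer token -> per-dimension mixed-radix digits.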
return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)
@property
def vocab_size(self):
return math.prod(self.bins_per_dim)
class ResNetDownBlock(nn.Module):
stride: int = 1
n_filters: int = 64
dropout_rate: float = 0.0
group_size: int = 32
@nn.compact
def __call__(self, x, *, train=True):
skip = x
if self.stride > 1 or x.shape[-1] != self.n_filters:
skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = nn.relu(x)
x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)
return skip + x
class ResNetUpBlock(nn.Module):
stride: int = 1
n_filters: int = 64
dropout_rate: float = 0.0
group_size: int = 32
@nn.compact
def __call__(self, x, *, train=True):
skip = x
if self.stride > 1:
skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = nn.relu(x)
x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)
return skip + x
@dataclass
class LfqCodebookOutput:
tokens: jnp.ndarray
z: jnp.ndarray
z_q: jnp.ndarray
token_log_probs: jnp.ndarray
commit_loss: jnp.ndarray
class LookupFreeQuantization(nn.Module):
num_dims: int
latent_dim: int
def setup(self):
self.codebook = jnp.array([-1, 1])
# self.activation = lambda x: x
self.activation = nn.tanh
self.project_down = nn.Dense(self.num_dims)
self.project_up = nn.Dense(self.latent_dim)
def encode(self, z):
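# Snap each projected dimension to the nearest codebook entry (-1 or 1) and pack the resulting bits into one integer token.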
z = self.project_down(z)
token_squared_distances = jnp.square(z[..., None] - self.codebook)
token_bits = jnp.argmin(token_squared_distances, axis=-1)
return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)
def decode(self, tokens):
token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
return self.project_up(self.codebook[token_bits])
def loss(self, x):
z = self.project_down(x)
z = self.activation(z)
token_squared_distances = jnp.square(z[..., None] - self.codebook)
tokens = jnp.argmin(token_squared_distances, axis=-1)
token_bit_log_probs = -token_squared_distances # jax.nn.log_softmax(-token_squared_distances, axis=-1)
# Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
token_bit_expansions = jnp.bitwise_and(
jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
).astype(jnp.int32)
token_log_probs = (
token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
+ token_bit_log_probs[..., 1] @ token_bit_expansions
) # (batch_size, num_tokens, 2 ** num_dims)
token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
z_q = self.codebook[tokens]
commit_loss = jnp.square(z - z_q).mean()
z_q = jax.lax.stop_gradient(z_q - z) + z
z_q = self.project_up(z_q)
z = self.project_up(z)
tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
return LfqCodebookOutput(
tokens=tokens,
z=z,
z_q=z_q,
token_log_probs=jnp.zeros(()),
commit_loss=commit_loss,
)
def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
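# Build a block-causal attention mask by pairwise comparing q // bs_k >= k // bs_q.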
return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
class GeGLU(Module):
"""Gated Linear Unit with GELU (GeGLU) activation function.
GeGLU is a Flax layer that combines a linear transformation with a GELU
activation function in a gating mechanism. It is often used in Transformer models
to provide non-linear capabilities while preserving a strong linear component.
Example usage::
>>> import flax.linen as nn
>>> class TransformerBlock(nn.Module):
... @nn.compact
... def __call__(self, x):
... x = nn.Dense(2)(x)
... x = GeGLU()(x)
... return x
Attributes:
output_dim: the number of output features (default: -1, which keeps the input dimension).
"""
output_dim: int = -1
@compact
def __call__(self, inputs: Array) -> Array:
"""Applies the GeGLU activation to the inputs.
Args:
inputs: the nd-array to apply the GeGLU activation function to.
Returns:
The transformed input.
"""
if self.output_dim == -1:
output_dim = inputs.shape[-1]
else:
output_dim = self.output_dim
x = nn.Dense(output_dim * 2)(inputs)
x, gate = x[..., :output_dim], x[..., output_dim:]
return x * nn.gelu(gate)
class CrossAttentionLayer(nn.Module):
dropout_rate: float = 0.0
num_heads: int | None = None
causal: bool = False
mlp_ratio: float = 4.0
@nn.compact
def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
d_embed = x.shape[-1]
seq_len_q = x.shape[-2]
seq_len_k = y.shape[-2]
if self.causal:
# One block size will be 1
bs_q = max(seq_len_q // seq_len_k, 1)
bs_k = max(seq_len_k // seq_len_q, 1)
mask_self = nn.make_causal_mask(x[..., 0])
mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)
# Self-attention block
skip = x
x = nn.LayerNorm()(x)
x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads or d_embed // 64,
dropout_rate=self.dropout_rate,
deterministic=not train,
)(x, x, x, mask=mask_self)
x = skip + x
# Cross-attention block
skip = x
x = nn.LayerNorm()(x)
# bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads or d_embed // 64,
dropout_rate=self.dropout_rate,
deterministic=not train,
# attention_fn=partial(nn.dot_product_attention, bias=bias),
)(x, y, y, mask=mask_cross)
x = skip + x
# MLP block
skip = x
x = nn.LayerNorm()(x)
x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = GeGLU()(x)
x = nn.Dense(d_embed)(x)
return skip + x
def sinusoidal_pe_init(_, shape):
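# Initializer for a (seq_len, d_embed) sinusoidal positional-encoding table; the unused first argument is the PRNG key.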
seq_len, d_embed = shape
position = jnp.arange(0, seq_len, 1)
div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
return jnp.concatenate(
[
jnp.sin(position[:, jnp.newaxis] * div_term),
jnp.cos(position[:, jnp.newaxis] * div_term),
],
axis=-1,
)
class TokenizerEncoderDecoder(nn.Module):
num_tokens: int
num_cross_tokens: int
num_layers: int
causal: bool
mlp_ratio: float = 4.0
use_state_conditioning: bool = False
@nn.compact
def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
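# x holds learned query embeddings (sinusoidally initialized), one per output token; they cross-attend to the input sequence y.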
x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
if mask is not None:
# mask is (batch_dims..., num_cross_tokens)
chex.assert_equal_shape([y[..., 0], mask])
attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
else:
attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens))
if self.use_state_conditioning:
assert state_conditioning is not None, "State conditioning is required for this model."
state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
y = jnp.concatenate([y, state_embed], axis=-2)
attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)
y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])
for _ in range(self.num_layers):
x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
x, y, train=train, mask_self=None, mask_cross=attn_mask
)
return x
class FsqAttentionTokenizer(nn.Module):
embed_dim: int
data_dim: int
data_horizon: int
num_tokens: int
num_layers: int
target_codebook_size: int
causal: bool = False
mlp_ratio: float = 2.0
bound: float | None = None
use_state_conditioning: bool = False
@property
def vocab_size(self):
return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))
def setup(self):
self.proj = nn.Dense(self.embed_dim)
self.encoder = TokenizerEncoderDecoder(
num_tokens=self.num_tokens,
num_cross_tokens=self.data_horizon,
num_layers=self.num_layers,
causal=self.causal,
use_state_conditioning=self.use_state_conditioning,
mlp_ratio=self.mlp_ratio,
)
self.codebook = FsqCodebook(
input_dim=self.embed_dim,
target_codebook_size=self.target_codebook_size,
codebook_type="custom",
)
self.decoder = TokenizerEncoderDecoder(
num_tokens=self.data_horizon,
num_cross_tokens=self.num_tokens,
num_layers=self.num_layers,
causal=self.causal,
use_state_conditioning=self.use_state_conditioning,
mlp_ratio=self.mlp_ratio,
)
self.proj_mean = nn.Dense(self.data_dim)
self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))
def tokenize(self, action, *, obs=None, train=False):
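# Encode an action chunk into discrete FSQ tokens; obs provides optional state conditioning for the encoder.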
if self.bound is not None:
action = jnp.clip(action, -self.bound, self.bound)
x = self.proj(action)
x = self.encoder(x, train=train, state_conditioning=obs)
return self.codebook.encode(x)
def detokenize(self, tokens, *, obs=None):
x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
mean = self.proj_mean(x)
return mean * self.out_scale
def loss(self, action, *, obs=None, train=True):
# Encode
x = self.proj(action)
z = self.encoder(x, train=train, state_conditioning=obs)
# Quantize
tokens, z = self.codebook(z)
# Decode
x = self.decoder(z, train=train, state_conditioning=obs)
mean = self.proj_mean(x) * self.out_scale
mse = jnp.mean(jnp.square(action - mean))
mae = jnp.mean(jnp.abs(action - mean))
return mse, {
"mse": mse,
"mae": mae,
}
def __call__(self, *args, **kwargs):
"""
Dummy for .init
"""
return self.loss(*args, **kwargs)

0
src/openpi/models/gemma.py Executable file → Normal file
View File

0
src/openpi/models/gemma_fast.py Executable file → Normal file
View File

Some files were not shown because too many files have changed in this diff.