Initial commit
This commit is contained in:
70
examples/aloha_real/Dockerfile
Normal file
70
examples/aloha_real/Dockerfile
Normal file
@@ -0,0 +1,70 @@
|
||||
# syntax=docker/dockerfile:1
# Dockerfile for the Aloha real environment.
#
# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
#
# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash

FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
SHELL ["/bin/bash", "-c"]

# Build-time only: keeps apt (and the Interbotix installer below) from prompting,
# without leaking the setting into the environment of the running container.
ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    cmake \
    curl \
    libffi-dev \
    python3-rosdep \
    python3-rosinstall \
    python3-rosinstall-generator \
    whiptail \
    git \
    wget \
    openssh-client \
    ros-noetic-cv-bridge \
    ros-noetic-usb-cam \
    ros-noetic-realsense2-camera \
    keyboard-configuration

WORKDIR /root
# -f makes the build fail on an HTTP error instead of saving the error page as the script.
RUN curl -fsSL 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' -o xsarm_amd64_install.sh && \
    chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n

# Add the ALOHA fork to the Interbotix workspace and rebuild it.
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make

# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
    cd /python && \
    wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
    tar -zxvf Python-3.10.14.tgz && \
    cd Python-3.10.14 && \
    ls -lhR && \
    ./configure --enable-optimizations && \
    make install && \
    echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    cd ~ && rm -rf /python && \
    rm -rf /var/lib/apt/lists/*

# Install the Python dependencies into the system Python 3.10 with uv.
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml

ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app

# Create an entrypoint script to run the setup commands, followed by the command passed in.
# `exec` replaces the wrapper shell with the command, so the command (not bash)
# receives signals such as the SIGTERM sent by `docker stop`.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && exec "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh

ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]
|
||||
73
examples/aloha_real/README.md
Normal file
73
examples/aloha_real/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Run Aloha (Real Robot)
|
||||
|
||||
This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
|
||||
|
||||
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
|
||||
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA --default_prompt='toast out of toaster'"
|
||||
docker compose -f examples/aloha_real/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.10 examples/aloha_real/.venv
|
||||
source examples/aloha_real/.venv/bin/activate
|
||||
uv pip sync examples/aloha_real/requirements.txt
|
||||
uv pip install -e packages/openpi-client
|
||||
|
||||
# Run the robot
|
||||
python examples/aloha_real/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
roslaunch --wait aloha ros_nodes.launch
|
||||
```
|
||||
|
||||
Terminal window 3:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env ALOHA --default_prompt='toast out of toaster'
|
||||
```
|
||||
|
||||
## Model Guide
|
||||
The Pi0 Base Model is an out-of-the-box model for general tasks. You can find more details in the [technical report](https://www.physicalintelligence.company/download/pi0.pdf).
|
||||
|
||||
While we strongly recommend fine-tuning the model to your own data to adapt it to particular tasks, it may be possible to prompt the model to attempt some tasks that were in the pre-training data. For example, below is a video of the model attempting the "toast out of toaster" task.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/Physical-Intelligence/openpi/blob/main/examples/aloha_real/toast.gif" alt="toast out of toaster"/>
|
||||
</p>
|
||||
|
||||
## Training on your own Aloha dataset
|
||||
|
||||
OpenPI supports training on data collected in the default aloha hdf5 format. To do so, you must first convert the data to the Hugging Face format. We include `scripts/aloha_hd5.py` to help you do this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/configs.py` and replace repo id with the id assigned to your dataset during conversion.
|
||||
|
||||
```python
|
||||
TrainConfig(
|
||||
name=<your-config-name>,
|
||||
data=LeRobotAlohaDataConfig(
|
||||
repo_id=<your-repo-id>,
|
||||
delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
|
||||
),
|
||||
),
|
||||
```
|
||||
|
||||
Run the training script:
|
||||
|
||||
```bash
|
||||
uv run scripts/train.py <your-config-name>
|
||||
```
|
||||
63
examples/aloha_real/compose.yml
Normal file
63
examples/aloha_real/compose.yml
Normal file
@@ -0,0 +1,63 @@
|
||||
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
  # Uses the aloha_real image with no command override, so the image's default command runs.
  runtime:
    image: aloha_real
    depends_on:
      - aloha_ros_nodes
      - ros_master
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      # $PWD is the directory `docker compose` is invoked from.
      - $PWD:/app
      - ../../data:/data

  # Same image as `runtime`, but launches the ALOHA ROS nodes instead of the default command.
  aloha_ros_nodes:
    image: aloha_real
    depends_on:
      - ros_master
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      # Host device nodes are mounted into the container.
      - /dev:/dev
    command: roslaunch --wait aloha ros_nodes.launch

  # Runs roscore from the stock ROS image.
  ros_master:
    image: ros:noetic-robot
    network_mode: host
    privileged: true
    command:
      - roscore

  # Policy server built from scripts/serve_policy.Dockerfile.
  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      # Passed through from the invoking shell's environment.
      - SERVER_ARGS

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
|
||||
71
examples/aloha_real/constants.py
Normal file
71
examples/aloha_real/constants.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
|
||||
### Task parameters

### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]

# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844

# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213

############################ Helper functions ############################
# Normalizing maps the CLOSE limit to 0 and the OPEN limit to 1; unnormalizing is the inverse.


def MASTER_GRIPPER_POSITION_NORMALIZE_FN(x):
    """Map a master finger position onto the close..open range."""
    return (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)


def PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x):
    """Map a puppet finger position onto the close..open range."""
    return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)


def MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(x):
    """Map a normalized value back to a master finger position."""
    return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE


def PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(x):
    """Map a normalized value back to a puppet finger position."""
    return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE


def MASTER2PUPPET_POSITION_FN(x):
    """Convert a master finger position to the corresponding puppet finger position."""
    return PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))


def MASTER_GRIPPER_JOINT_NORMALIZE_FN(x):
    """Map a master gripper joint value onto the close..open range."""
    return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)


def PUPPET_GRIPPER_JOINT_NORMALIZE_FN(x):
    """Map a puppet gripper joint value onto the close..open range."""
    return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)


def MASTER_GRIPPER_JOINT_UNNORMALIZE_FN(x):
    """Map a normalized value back to a master gripper joint value."""
    return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE


def PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(x):
    """Map a normalized value back to a puppet gripper joint value."""
    return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE


def MASTER2PUPPET_JOINT_FN(x):
    """Convert a master gripper joint value to the corresponding puppet joint value."""
    return PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))


def MASTER_GRIPPER_VELOCITY_NORMALIZE_FN(x):
    """Scale a master gripper velocity by the master finger position range."""
    return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)


def PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(x):
    """Scale a puppet gripper velocity by the puppet finger position range."""
    return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)


def MASTER_POS2JOINT(x):
    """Convert a master finger position to a master gripper joint value."""
    return (
        MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
        + MASTER_GRIPPER_JOINT_CLOSE
    )


def MASTER_JOINT2POS(x):
    """Convert a master gripper joint value to a master finger position."""
    normalized = (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    return MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(normalized)


def PUPPET_POS2JOINT(x):
    """Convert a puppet finger position to a puppet gripper joint value."""
    return (
        PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
        + PUPPET_GRIPPER_JOINT_CLOSE
    )


def PUPPET_JOINT2POS(x):
    """Convert a puppet gripper joint value to a puppet finger position."""
    normalized = (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    return PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized)


MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
52
examples/aloha_real/env.py
Normal file
52
examples/aloha_real/env.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import einops
|
||||
import numpy as np
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
from examples.aloha_real import real_env as _real_env
|
||||
|
||||
|
||||
class AlohaRealEnvironment(_environment.Environment):
    """An environment for an Aloha robot on real hardware."""

    def __init__(self, render_height: int = 480, render_width: int = 640) -> None:
        # The underlying env is created with init_node=True.
        self._env = _real_env.make_real_env(init_node=True)
        self._render_height = render_height
        self._render_width = render_width

        # Most recent timestep; stays None until reset() is called.
        self._ts = None

    @override
    def reset(self) -> None:
        """Reset the underlying env and keep the timestep it returns."""
        self._ts = self._env.reset()

    @override
    def done(self) -> bool:
        """Always reports that the episode is not finished."""
        return False

    @override
    def get_observation(self) -> dict:
        """Return `qpos` and the stacked, channel-first camera images of the latest timestep.

        Raises:
            RuntimeError: if no timestep has been stored yet.
        """
        if self._ts is None:
            raise RuntimeError("Timestep is not set. Call reset() first.")

        obs = self._ts.observation
        camera_images = obs["images"]

        # Drop depth entries in place; only the remaining images are stacked.
        depth_keys = [name for name in camera_images if "_depth" in name]
        for name in depth_keys:
            del camera_images[name]

        channel_first = [einops.rearrange(frame, "h w c -> c h w") for frame in camera_images.values()]
        stacked_images = np.stack(channel_first, axis=0).astype(np.uint8)

        # TODO: Consider removing these transformations.
        observation = {}
        observation["qpos"] = obs["qpos"]
        observation["image"] = stacked_images
        return observation

    @override
    def apply_action(self, action: dict) -> None:
        """Step the underlying env with `action["qpos"]` and keep the resulting timestep."""
        self._ts = self._env.step(action["qpos"])
|
||||
42
examples/aloha_real/main.py
Normal file
42
examples/aloha_real/main.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import tyro
|
||||
|
||||
from examples.aloha_real import env as _env
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Args:
    """Command-line arguments for the Aloha real-robot runner."""

    # Host handed to the websocket client policy.
    host: str = "0.0.0.0"
    # Port handed to the websocket client policy.
    port: int = 8000

    # Forwarded to ActionChunkBroker as `action_horizon`.
    action_horizon: int = 25
|
||||
|
||||
|
||||
def main(args: Args) -> None:
    """Build the runtime (environment, policy agent) from `args` and run it."""
    # The environment is constructed first, before the policy client.
    environment = _env.AlohaRealEnvironment()

    remote_policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    chunked_policy = action_chunk_broker.ActionChunkBroker(
        policy=remote_policy,
        action_horizon=args.action_horizon,
    )
    agent = _policy_agent.PolicyAgent(policy=chunked_policy)

    runtime = _runtime.Runtime(
        environment=environment,
        agent=agent,
        subscribers=[],
        max_hz=50,
    )

    runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
167
examples/aloha_real/real_env.py
Normal file
167
examples/aloha_real/real_env.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
import collections
|
||||
import time
|
||||
|
||||
import dm_env
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
|
||||
from examples.aloha_real import constants
|
||||
from examples.aloha_real import robot_utils
|
||||
|
||||
|
||||
class RealEnv:
    """
    Environment for real robot bi-manual manipulation
    Action space:      [left_arm_qpos (6),             # absolute joint position
                        left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)
                        right_arm_qpos (6),            # absolute joint position
                        right_gripper_positions (1),]  # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                        right_arm_qpos (6),         # absolute joint position
                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[ left_arm_qvel (6),         # absolute joint velocity (rad)
                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                        right_arm_qvel (6),         # absolute joint velocity (rad)
                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
                        "effort": Concat[ left_effort[:7], right_effort[:7] ]  # see get_effort()
                        "images": {"cam_high": (480x640x3),        # h, w, c, dtype='uint8'
                                   "cam_low": (480x640x3),         # h, w, c, dtype='uint8'
                                   "cam_left_wrist": (480x640x3),  # h, w, c, dtype='uint8'
                                   "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
    """

    def __init__(self, init_node, *, setup_robots: bool = True):
        """Create the two puppet arm handles, the state/image recorders and the gripper command.

        Args:
            init_node: forwarded to the left arm's InterbotixManipulatorXS; every other
                object here is created with init_node=False.
            setup_robots: when true, setup_robots() is called right after the arms are created.
        """
        self.puppet_bot_left = InterbotixManipulatorXS(
            robot_model="vx300s",
            group_name="arm",
            gripper_name="gripper",
            robot_name="puppet_left",
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
        )
        if setup_robots:
            self.setup_robots()

        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        # One command message shared by both grippers; set_gripper_pose() overwrites its cmd.
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        """Run robot_utils.setup_puppet_bot on the left arm, then the right arm."""
        robot_utils.setup_puppet_bot(self.puppet_bot_left)
        robot_utils.setup_puppet_bot(self.puppet_bot_right)

    def get_qpos(self):
        """Return [left arm (6), left gripper (1), right arm (6), right gripper (1)] positions.

        Arm values are the first six recorded entries; the gripper value is recorded
        entry 7 passed through PUPPET_GRIPPER_POSITION_NORMALIZE_FN.
        """
        left_qpos_raw = self.recorder_left.qpos
        right_qpos_raw = self.recorder_right.qpos
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
        ]  # this is position not joint
        right_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
        ]  # this is position not joint
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    def get_qvel(self):
        """Return velocities in the same layout as get_qpos().

        The gripper value is recorded entry 7 passed through PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN.
        """
        left_qvel_raw = self.recorder_left.qvel
        right_qvel_raw = self.recorder_right.qvel
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    def get_effort(self):
        """Return the first seven recorded effort entries of the left arm followed by the right arm."""
        left_effort_raw = self.recorder_left.effort
        right_effort_raw = self.recorder_right.effort
        left_robot_effort = left_effort_raw[:7]
        right_robot_effort = right_effort_raw[:7]
        return np.concatenate([left_robot_effort, right_robot_effort])

    def get_images(self):
        """Return whatever the image recorder's get_images() returns."""
        return self.image_recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        """Publish gripper commands for the left and then the right puppet arm.

        Each normalized input is converted with PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN before publishing.
        """
        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
        self.gripper_command.cmd = left_gripper_desired_joint
        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)

        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
            right_gripper_desired_pos_normalized
        )
        # The same message object is reused with a new cmd for the right gripper.
        self.gripper_command.cmd = right_gripper_desired_joint
        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        """Move both arms to a fixed reset pose over one second."""
        # reset_position = START_ARM_POSE[:6]
        reset_position = [0, -1.5, 1.5, 0, 0, 0]
        robot_utils.move_arms(
            [self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1
        )

    def _reset_gripper(self):
        """Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
        )
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
        )

    def get_observation(self):
        """Collect qpos, qvel, effort and images into an ordered dict, in that key order."""
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def get_reward(self):
        """Always returns 0."""
        return 0

    def reset(self, *, fake=False):
        """Return a FIRST timestep holding the current observation.

        Unless `fake` is true, the gripper motors are rebooted and the arm joints and
        grippers are reset before the observation is taken.
        """
        if not fake:
            # Reboot puppet robot gripper motors
            self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
            self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
            self._reset_joints()
            self._reset_gripper()
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )

    def step(self, action):
        """Send `action` to both arms and return a MID timestep with the new observation.

        The first half of `action` goes to the left arm and the second half to the right.
        Within each half, the first six entries are joint targets and the last entry is
        the normalized gripper target.
        """
        state_len = int(len(action) / 2)
        left_action = action[:state_len]
        right_action = action[state_len:]
        self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
        self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
        self.set_gripper_pose(left_action[-1], right_action[-1])
        # Pause for one control period before reading back the observation.
        time.sleep(constants.DT)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
    """Build a 14-element action from the two master arms' current joint states.

    Layout: left arm (6), left gripper (1), right arm (6), right gripper (1). The
    gripper entries are passed through MASTER_GRIPPER_JOINT_NORMALIZE_FN.
    """
    masters = (master_bot_left, master_bot_right)
    action = np.zeros(14)  # 6 joint + 1 gripper, for two arms

    # Arm actions
    for side, bot in enumerate(masters):
        start = 7 * side
        action[start : start + 6] = bot.dxl.joint_states.position[:6]

    # Gripper actions
    for side, bot in enumerate(masters):
        gripper_joint = bot.dxl.joint_states.position[6]
        action[7 * side + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(gripper_joint)

    return action
|
||||
|
||||
|
||||
def make_real_env(init_node, *, setup_robots: bool = True) -> RealEnv:
    """Construct a RealEnv, forwarding both arguments unchanged."""
    env = RealEnv(init_node, setup_robots=setup_robots)
    return env
|
||||
18
examples/aloha_real/requirements.in
Normal file
18
examples/aloha_real/requirements.in
Normal file
@@ -0,0 +1,18 @@
|
||||
Pillow
|
||||
dm_control
|
||||
einops
|
||||
h5py
|
||||
matplotlib
|
||||
modern_robotics
|
||||
msgpack
|
||||
numpy
|
||||
opencv-python
|
||||
packaging
|
||||
pexpect
|
||||
pyquaternion
|
||||
pyrealsense2
|
||||
pyyaml
|
||||
requests
|
||||
rospkg
|
||||
tyro
|
||||
websockets
|
||||
156
examples/aloha_real/requirements.txt
Normal file
156
examples/aloha_real/requirements.txt
Normal file
@@ -0,0 +1,156 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
catkin-pkg==1.0.0
|
||||
# via rospkg
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
contourpy==1.1.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
distro==1.9.0
|
||||
# via rospkg
|
||||
dm-control==1.0.23
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
docutils==0.20.1
|
||||
# via catkin-pkg
|
||||
einops==0.8.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
h5py==3.11.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
idna==3.10
|
||||
# via requests
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.7.5
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
modern-robotics==1.1.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mujoco==3.2.3
|
||||
# via dm-control
|
||||
numpy==1.24.4
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# h5py
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# modern-robotics
|
||||
# mujoco
|
||||
# opencv-python
|
||||
# pyquaternion
|
||||
# scipy
|
||||
opencv-python==4.10.0.84
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
packaging==24.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
pexpect==4.9.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.1.4
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pyrealsense2==2.55.1.6486
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# matplotlib
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# rospkg
|
||||
requests==2.32.3
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
rospkg==1.5.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
scipy==1.10.1
|
||||
# via dm-control
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
zipp==3.20.2
|
||||
# via etils
|
||||
275
examples/aloha_real/robot_utils.py
Normal file
275
examples/aloha_real/robot_utils.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
from collections import deque
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
|
||||
from aloha.msg import RGBGrayscaleImage
|
||||
from cv_bridge import CvBridge
|
||||
from interbotix_xs_msgs.msg import JointGroupCommand
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
from examples.aloha_real import constants
|
||||
|
||||
|
||||
class ImageRecorder:
    """Subscribes to one topic per camera and keeps the latest RGB frame and timestamp of each."""

    def __init__(self, init_node=True, is_debug=False):
        """Subscribe to `/<cam_name>` for every camera, then pause briefly.

        Args:
            init_node: when true, initialise an anonymous "image_recorder" ROS node first.
            is_debug: when true, also keep the last 50 image timestamps per camera.
        """
        self.is_debug = is_debug
        self.bridge = CvBridge()
        self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]

        if init_node:
            rospy.init_node("image_recorder", anonymous=True)

        # Per-camera callback lookup; a camera without an entry is rejected below.
        callbacks = {
            "cam_high": self.image_cb_cam_high,
            "cam_low": self.image_cb_cam_low,
            "cam_left_wrist": self.image_cb_cam_left_wrist,
            "cam_right_wrist": self.image_cb_cam_right_wrist,
        }
        for cam_name in self.camera_names:
            for suffix, initial in (("rgb_image", None), ("depth_image", None), ("timestamp", 0.0)):
                setattr(self, f"{cam_name}_{suffix}", initial)
            callback_func = callbacks.get(cam_name)
            if callback_func is None:
                raise NotImplementedError
            rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
            if self.is_debug:
                setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))

        self.cam_last_timestamps = dict.fromkeys(self.camera_names, 0.0)
        time.sleep(0.5)

    def image_cb(self, cam_name, data):
        """Store the decoded first image of `data` and the message timestamp for `cam_name`."""
        rgb_frame = self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8")
        setattr(self, f"{cam_name}_rgb_image", rgb_frame)
        # Depth decoding is not performed here, so the depth attribute keeps its initial value.
        stamp = data.header.stamp
        setattr(self, f"{cam_name}_timestamp", stamp.secs + stamp.nsecs * 1e-9)
        if self.is_debug:
            image_stamp = data.images[0].header.stamp
            getattr(self, f"{cam_name}_timestamps").append(image_stamp.secs + image_stamp.nsecs * 1e-9)

    def image_cb_cam_high(self, data):
        """Callback for the `cam_high` topic."""
        return self.image_cb("cam_high", data)

    def image_cb_cam_low(self, data):
        """Callback for the `cam_low` topic."""
        return self.image_cb("cam_low", data)

    def image_cb_cam_left_wrist(self, data):
        """Callback for the `cam_left_wrist` topic."""
        return self.image_cb("cam_left_wrist", data)

    def image_cb_cam_right_wrist(self, data):
        """Callback for the `cam_right_wrist` topic."""
        return self.image_cb("cam_right_wrist", data)

    def get_images(self):
        """Return `{cam: rgb, cam_depth: depth}` for every camera.

        For each camera this waits until its stored timestamp is newer than the one
        returned by the previous call.
        """
        image_dict = {}
        for cam_name in self.camera_names:
            stamp_attr = f"{cam_name}_timestamp"
            while not getattr(self, stamp_attr) > self.cam_last_timestamps[cam_name]:
                time.sleep(0.00001)
            rgb_image = getattr(self, f"{cam_name}_rgb_image")
            depth_image = getattr(self, f"{cam_name}_depth_image")
            self.cam_last_timestamps[cam_name] = getattr(self, stamp_attr)
            image_dict[cam_name] = rgb_image
            image_dict[f"{cam_name}_depth"] = depth_image
        return image_dict

    def print_diagnostics(self):
        """Print the mean image frequency per camera from the debug timestamp buffers."""

        def dt_helper(l):
            # Mean gap between consecutive timestamps.
            return np.mean(np.diff(np.array(l)))

        for cam_name in self.camera_names:
            image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
            print(f"{cam_name} {image_freq=:.2f}")
        print()
|
||||
|
||||
|
||||
class Recorder:
    """Keeps the latest joint state and command messages of one puppet arm."""

    def __init__(self, side, init_node=True, is_debug=False):
        """Subscribe to the `/puppet_<side>/...` topics, then pause briefly.

        Args:
            side: used to build the topic names, e.g. "left" -> `/puppet_left/joint_states`.
            init_node: when true, initialise an anonymous "recorder" ROS node first.
            is_debug: when true, also keep the last 50 receive times per topic.
        """
        self.secs = None
        self.nsecs = None
        self.qpos = None
        # qvel and data are written by puppet_state_cb; initialise them like the other
        # fields so reading them before the first message does not raise AttributeError.
        self.qvel = None
        self.effort = None
        self.data = None
        self.arm_command = None
        self.gripper_command = None
        self.is_debug = is_debug

        # Create the debug buffers before subscribing: the callbacks append to them.
        if self.is_debug:
            self.joint_timestamps = deque(maxlen=50)
            self.arm_command_timestamps = deque(maxlen=50)
            self.gripper_command_timestamps = deque(maxlen=50)

        if init_node:
            rospy.init_node("recorder", anonymous=True)
        rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_group",
            JointGroupCommand,
            self.puppet_arm_commands_cb,
        )
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_single",
            JointSingleCommand,
            self.puppet_gripper_commands_cb,
        )
        time.sleep(0.1)

    def puppet_state_cb(self, data):
        """Store position, velocity, effort and the raw message of a joint-state update."""
        self.qpos = data.position
        self.qvel = data.velocity
        self.effort = data.effort
        self.data = data
        if self.is_debug:
            self.joint_timestamps.append(time.time())

    def puppet_arm_commands_cb(self, data):
        """Store the `cmd` of the latest arm group command."""
        self.arm_command = data.cmd
        if self.is_debug:
            self.arm_command_timestamps.append(time.time())

    def puppet_gripper_commands_cb(self, data):
        """Store the `cmd` of the latest gripper command."""
        self.gripper_command = data.cmd
        if self.is_debug:
            self.gripper_command_timestamps.append(time.time())

    def print_diagnostics(self):
        """Print the mean message frequency per topic from the debug timestamp buffers."""

        def dt_helper(l):
            # Mean gap between consecutive timestamps.
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        joint_freq = 1 / dt_helper(self.joint_timestamps)
        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)

        print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
||||
|
||||
|
||||
def get_arm_joint_positions(bot):
    """Return the first six entries of ``bot.arm.core.joint_states.position``."""
    joint_states = bot.arm.core.joint_states
    return joint_states.position[:6]
|
||||
|
||||
|
||||
def get_arm_gripper_positions(bot):
    """Return entry 6 of ``bot.gripper.core.joint_states.position``."""
    joint_states = bot.gripper.core.joint_states
    return joint_states.position[6]
|
||||
|
||||
|
||||
def move_arms(bot_list, target_pose_list, move_time=1):
    """Drive several arms to target joint poses along linearly interpolated paths.

    Args:
        bot_list: Bots to move; each one's current pose is read with
            ``get_arm_joint_positions``.
        target_pose_list: One target pose per bot, paired with ``bot_list`` by position.
        move_time: Divided by ``constants.DT`` to obtain the number of
            interpolation steps.
    """
    step_count = int(move_time / constants.DT)
    start_poses = [get_arm_joint_positions(bot) for bot in bot_list]

    # One interpolated path per (current pose, target pose) pair.
    trajectories = []
    for start, goal in zip(start_poses, target_pose_list):
        trajectories.append(np.linspace(start, goal, step_count))

    for step in range(step_count):
        for index, bot in enumerate(bot_list):
            bot.arm.set_joint_positions(trajectories[index][step], blocking=False)
        # Wait one control period before commanding the next waypoint.
        time.sleep(constants.DT)
|
||||
|
||||
|
||||
def move_grippers(bot_list, target_pose_list, move_time):
    """Drive several grippers to target positions and log the trajectory to a JSONL file.

    The gripper positions are interpolated linearly from the current reading
    to the target over ``int(move_time / constants.DT)`` steps. At every step
    the command is published for each bot, and one JSON line is appended to a
    timestamped file under ``/data``. The line maps each bot index to the
    observed gripper position (``"obs"``) and the commanded value (``"act"``).

    Args:
        bot_list: Bots whose grippers are commanded.
        target_pose_list: One target gripper position per bot.
        move_time: Divided by ``constants.DT`` to obtain the number of steps.
    """
    print(f"Moving grippers to {target_pose_list=}")
    gripper_command = JointSingleCommand(name="gripper")
    step_count = int(move_time / constants.DT)
    start_positions = [get_arm_gripper_positions(bot) for bot in bot_list]

    trajectories = []
    for start, goal in zip(start_positions, target_pose_list):
        trajectories.append(np.linspace(start, goal, step_count))

    with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
        for step in range(step_count):
            record = {}
            for index, bot in enumerate(bot_list):
                commanded = trajectories[index][step]
                # The same message object is reused for every bot and step.
                gripper_command.cmd = commanded
                bot.gripper.core.pub_single.publish(gripper_command)
                record[index] = {"obs": get_arm_gripper_positions(bot), "act": trajectories[index][step]}
            f.write(json.dumps(record) + "\n")
            time.sleep(constants.DT)
|
||||
|
||||
|
||||
def setup_puppet_bot(bot):
    """Prepare a puppet bot: reboot the gripper motor, set operating modes, enable torque.

    The arm group is put into ``position`` mode and the gripper into
    ``current_based_position`` mode before ``torque_on`` is called.
    """
    driver = bot.dxl
    driver.robot_reboot_motors("single", "gripper", True)
    driver.robot_set_operating_modes("group", "arm", "position")
    driver.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_on(bot)
|
||||
|
||||
|
||||
def setup_master_bot(bot):
    """Prepare a master bot: set operating modes, then disable torque.

    The arm group is put into ``pwm`` mode and the gripper into
    ``current_based_position`` mode before ``torque_off`` is called.
    """
    driver = bot.dxl
    driver.robot_set_operating_modes("group", "arm", "pwm")
    driver.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_off(bot)
|
||||
|
||||
|
||||
def set_standard_pid_gains(bot):
    """Write the standard position gains (P=800, I=0) to the arm group's motor registers."""
    for register, value in (("Position_P_Gain", 800), ("Position_I_Gain", 0)):
        bot.dxl.robot_set_motor_registers("group", "arm", register, value)
|
||||
|
||||
|
||||
def set_low_pid_gains(bot):
    """Write the low position gains (P=100, I=0) to the arm group's motor registers."""
    for register, value in (("Position_P_Gain", 100), ("Position_I_Gain", 0)):
        bot.dxl.robot_set_motor_registers("group", "arm", register, value)
|
||||
|
||||
|
||||
def torque_off(bot):
    """Disable torque on the arm group, then on the gripper motor."""
    for kind, name in (("group", "arm"), ("single", "gripper")):
        bot.dxl.robot_torque_enable(kind, name, False)
|
||||
|
||||
|
||||
def torque_on(bot):
    """Enable torque on the arm group, then on the gripper motor."""
    for kind, name in (("group", "arm"), ("single", "gripper")):
        bot.dxl.robot_torque_enable(kind, name, True)
|
||||
|
||||
|
||||
# for DAgger
|
||||
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
    """Move both master arms and grippers to the current pose of their puppets.

    Torque is enabled on the masters first. The puppets' arm joint positions
    and gripper positions are then read, and the masters are driven to those
    values: arms first, then grippers, each with a move time of 1.
    """
    print("\nSyncing!")

    masters = [master_bot_left, master_bot_right]
    puppets = [puppet_bot_left, puppet_bot_right]

    # Activate the master arms so they can be commanded.
    for master in masters:
        torque_on(master)

    # Read the puppets' arm poses, then their gripper positions (left before right).
    arm_targets = [get_arm_joint_positions(puppet) for puppet in puppets]
    gripper_targets = [get_arm_gripper_positions(puppet) for puppet in puppets]

    # Arms are moved before grippers.
    move_arms(masters, arm_targets, move_time=1)
    move_grippers(masters, gripper_targets, move_time=1)
|
||||
BIN
examples/aloha_real/toast.gif
Normal file
BIN
examples/aloha_real/toast.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 23 MiB |
36
examples/aloha_real/video_display.py
Normal file
36
examples/aloha_real/video_display.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoDisplay(_subscriber.Subscriber):
    """Displays video frames."""

    def __init__(self) -> None:
        # Both handles are created lazily: the axes when an episode starts,
        # the image artist on the first step of that episode.
        self._ax: plt.Axes | None = None
        self._plt_img: plt.Image | None = None

    @override
    def on_episode_start(self) -> None:
        """Switch matplotlib to interactive mode and create fresh axes."""
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        """Show the first image of ``observation["image"]`` in the figure."""
        assert self._ax is not None

        # Take the first camera image ([C, H, W]) and reorder it to [H, W, C].
        frame = np.transpose(observation["image"][0], (1, 2, 0))

        if self._plt_img is not None:
            # Reuse the existing artist instead of drawing a new one each step.
            self._plt_img.set_data(frame)
        else:
            self._plt_img = self._ax.imshow(frame)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        """Leave interactive mode and close the figure."""
        plt.ioff()
        plt.close()
|
||||
41
examples/aloha_sim/Dockerfile
Normal file
41
examples/aloha_sim/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
# Dockerfile for the Aloha simulation environment.

# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash

FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

# Install the OpenGL / Mesa libraries used for rendering. The apt package lists
# are removed in the same layer so they do not end up in the image.
RUN apt-get update && \
    apt-get install -y \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev && \
    rm -rf /var/lib/apt/lists/*
ENV MUJOCO_GL=egl

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
|
||||
36
examples/aloha_sim/README.md
Normal file
36
examples/aloha_sim/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Run Aloha Sim
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA_SIM"
|
||||
docker compose -f examples/aloha_sim/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.10 examples/aloha_sim/.venv
|
||||
source examples/aloha_sim/.venv/bin/activate
|
||||
uv pip sync examples/aloha_sim/requirements.txt
|
||||
uv pip install -e packages/openpi-client
|
||||
|
||||
# Run the simulation
|
||||
MUJOCO_GL=egl python examples/aloha_sim/main.py
|
||||
```
|
||||
|
||||
Note: If you are seeing EGL errors, you may need to install the following dependencies:
|
||||
|
||||
```bash
|
||||
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env ALOHA_SIM
|
||||
```
|
||||
39
examples/aloha_sim/compose.yml
Normal file
39
examples/aloha_sim/compose.yml
Normal file
@@ -0,0 +1,39 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/aloha_sim/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: aloha_sim
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_sim/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
56
examples/aloha_sim/env.py
Normal file
56
examples/aloha_sim/env.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import gym_aloha # noqa: F401
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class AlohaSimEnvironment(_environment.Environment):
    """An environment for an Aloha robot in simulation."""

    def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
        """Create the gymnasium environment for ``task`` and seed the random generators.

        Args:
            task: Environment id passed to ``gymnasium.make``.
            obs_type: Forwarded to ``gymnasium.make`` as ``obs_type``.
            seed: Seeds both the global NumPy RNG and the private generator
                that supplies a new seed for each ``reset``.
        """
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        self._gym = gymnasium.make(task, obs_type=obs_type)

        # No observation exists until reset() is called; until then the episode counts as done.
        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        """Reset the simulation with a freshly drawn seed and cache the first observation."""
        episode_seed = int(self._rng.integers(2**32 - 1))
        gym_obs, _ = self._gym.reset(seed=episode_seed)
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def done(self) -> bool:
        """Return whether the current episode has terminated or been truncated."""
        return self._done

    @override
    def get_observation(self) -> dict:
        """Return the cached observation; raises RuntimeError if ``reset`` was never called."""
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        """Step the simulation with ``action["qpos"]`` and update the cached state."""
        gym_obs, reward, terminated, truncated, _ = self._gym.step(action["qpos"])
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = terminated or truncated
        # Keep the highest reward observed so far in this episode.
        self._episode_reward = max(self._episode_reward, reward)

    def _convert_observation(self, gym_obs: dict) -> dict:
        """Turn a gym observation into a dict with ``qpos`` and a stacked ``image`` array."""
        # Convert axis order from [H, W, C] --> [C, H, W]
        top_view = np.transpose(gym_obs["pixels"]["top"], (2, 0, 1))

        # Add multi-camera dimension, to match the way real aloha provides images as [cam_idx, C, H, W].
        stacked = np.expand_dims(top_view, axis=0)

        return {"qpos": gym_obs["agent_pos"], "image": stacked}
|
||||
55
examples/aloha_sim/main.py
Normal file
55
examples/aloha_sim/main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import env as _env
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import saver as _saver
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Args:
    """Command-line options of the Aloha simulation example."""

    # Handed to the video saver as its output path.
    out_path: pathlib.Path = pathlib.Path("out.mp4")

    # Environment id and seed handed to the simulation environment.
    task: str = "gym_aloha/AlohaTransferCube-v0"
    seed: int = 0

    # Handed to the action chunk broker as its action horizon.
    action_horizon: int = 10

    # Address of the policy server the websocket client connects to.
    host: str = "0.0.0.0"
    port: int = 8000

    # NOTE(review): not read anywhere in this file's main() — confirm whether it is still needed.
    display: bool = False
|
||||
|
||||
|
||||
def main(args: Args) -> None:
    """Build the runtime (environment, policy agent, video saver) and run it.

    The agent wraps a websocket policy client in an action chunk broker; the
    runtime is capped at 50 Hz and writes a video to ``args.out_path``.
    """
    environment = _env.AlohaSimEnvironment(task=args.task, seed=args.seed)

    remote_policy = _websocket_client_policy.WebsocketClientPolicy(host=args.host, port=args.port)
    chunked_policy = action_chunk_broker.ActionChunkBroker(
        policy=remote_policy,
        action_horizon=args.action_horizon,
    )
    agent = _policy_agent.PolicyAgent(policy=chunked_policy)

    video_saver = _saver.VideoSaver(args.out_path)

    runtime = _runtime.Runtime(
        environment=environment,
        agent=agent,
        subscribers=[video_saver],
        max_hz=50,
    )
    runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
8
examples/aloha_sim/requirements.in
Normal file
8
examples/aloha_sim/requirements.in
Normal file
@@ -0,0 +1,8 @@
|
||||
gym-aloha
|
||||
imageio
|
||||
matplotlib
|
||||
msgpack
|
||||
numpy
|
||||
typing-extensions
|
||||
tyro
|
||||
websockets
|
||||
132
examples/aloha_sim/requirements.txt
Normal file
132
examples/aloha_sim/requirements.txt
Normal file
@@ -0,0 +1,132 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cloudpickle==3.1.0
|
||||
# via gymnasium
|
||||
contourpy==1.3.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
gym-aloha==0.1.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
gymnasium==1.0.0
|
||||
# via gym-aloha
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.36.1
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gym-aloha
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.9.3
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# gymnasium
|
||||
# imageio
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# scipy
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==11.0.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.0
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
requests==2.32.3
|
||||
# via dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
scipy==1.14.1
|
||||
# via dm-control
|
||||
setuptools==75.6.0
|
||||
# via
|
||||
# dm-control
|
||||
# imageio-ffmpeg
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.1
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gymnasium
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
35
examples/aloha_sim/saver.py
Normal file
35
examples/aloha_sim/saver.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoSaver(_subscriber.Subscriber):
    """Saves episode data."""

    def __init__(self, out_path: pathlib.Path, subsample: int = 1) -> None:
        """Remember where to write the video and how strongly to subsample frames.

        Args:
            out_path: File the video is written to at the end of an episode.
            subsample: Only every ``subsample``-th frame is written, and the
                output frame rate is reduced accordingly.
        """
        self._out_path = out_path
        self._images: list[np.ndarray] = []
        self._subsample = subsample

    @override
    def on_episode_start(self) -> None:
        """Drop the frames collected during the previous episode."""
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        """Record the first image of ``observation["image"]`` as an [H, W, C] frame."""
        first_image = observation["image"][0]  # [C, H, W]
        self._images.append(np.transpose(first_image, (1, 2, 0)))  # [H, W, C]

    @override
    def on_episode_end(self) -> None:
        """Write the collected (subsampled) frames to ``out_path`` as a video."""
        logging.info(f"Saving video to {self._out_path}")
        frames = [np.asarray(x) for x in self._images[:: self._subsample]]
        # Scale the 50 fps base rate down by the subsampling factor.
        frame_rate = 50 // max(1, self._subsample)
        imageio.mimwrite(self._out_path, frames, fps=frame_rate)
|
||||
65
examples/calvin/Dockerfile
Normal file
65
examples/calvin/Dockerfile
Normal file
@@ -0,0 +1,65 @@
|
||||
# THIS DOCKERFILE DOES NOT YET WORK
# Dockerfile for the CALVIN benchmark.

# Build the container:
# docker build . -t calvin -f examples/calvin/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app --privileged --gpus all calvin /bin/bash

FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
SHELL ["/bin/bash", "-c"]

ENV DEBIAN_FRONTEND=noninteractive
# The apt package lists are removed in the same layer so they do not end up in the image.
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    make \
    g++ \
    git \
    wget \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6 \
    unzip \
    ffmpeg && \
    rm -rf /var/lib/apt/lists/*

# Install miniconda (the installer script is deleted in the same layer).
ENV CONDA_DIR=/opt/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
    /bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \
    rm ~/miniconda.sh
ENV PATH=$CONDA_DIR/bin:$PATH

# Submodules don't work with calvin because it internally parses git metadata.
# So we have to clone it directly.
RUN git clone --recurse-submodules https://github.com/mees/calvin.git /root/calvin

# -y answers conda's confirmation prompt; a docker build has no terminal to answer it.
RUN conda create -y -n calvin python=3.8
RUN source /opt/conda/bin/activate calvin && \
    pip install setuptools==57.5.0 && \
    cd /root/calvin && \
    ./install.sh && \
    pip install \
    imageio[ffmpeg] \
    moviepy \
    numpy==1.23.0 \
    tqdm \
    tyro \
    websockets \
    msgpack

ENV PYTHONPATH=/app:/app/packages/openpi-client/src

# Download CALVIN dataset, see https://github.com/mees/calvin/blob/main/dataset/download_data.sh
RUN mkdir -p /datasets && cd /datasets && \
    wget http://calvin.cs.uni-freiburg.de/dataset/calvin_debug_dataset.zip && \
    unzip calvin_debug_dataset.zip && \
    rm calvin_debug_dataset.zip

WORKDIR /app
CMD ["/bin/bash", "-c", "source /opt/conda/bin/activate calvin && python examples/calvin/main.py"]
|
||||
47
examples/calvin/README.md
Normal file
47
examples/calvin/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# CALVIN Benchmark
|
||||
|
||||
This example runs the CALVIN benchmark: https://github.com/mees/calvin
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env CALVIN"
|
||||
docker compose -f examples/calvin/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
cd $OPENPI_ROOT
|
||||
conda create -n calvin python=3.8
|
||||
conda activate calvin
|
||||
|
||||
git clone --recurse-submodules https://github.com/mees/calvin.git
|
||||
cd calvin
|
||||
pip install setuptools==57.5.0
|
||||
./install.sh
|
||||
|
||||
pip install imageio[ffmpeg] moviepy numpy==1.23.0 tqdm tyro websockets msgpack
|
||||
export PYTHONPATH=$PYTHONPATH:$OPENPI_ROOT/packages/openpi-client/src
|
||||
|
||||
# Download CALVIN dataset, see https://github.com/mees/calvin/blob/main/dataset/download_data.sh
|
||||
export CALVIN_DATASETS_DIR=~/datasets
|
||||
export CALVIN_DATASET=calvin_debug_dataset
|
||||
mkdir -p $CALVIN_DATASETS_DIR && cd $CALVIN_DATASETS_DIR
|
||||
wget http://calvin.cs.uni-freiburg.de/dataset/$CALVIN_DATASET.zip
|
||||
unzip $CALVIN_DATASET.zip
|
||||
rm $CALVIN_DATASET.zip
|
||||
|
||||
# Run the simulation
|
||||
cd $OPENPI_ROOT
|
||||
python examples/calvin/main.py --args.calvin_data_path=$CALVIN_DATASETS_DIR/$CALVIN_DATASET
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env CALVIN
|
||||
```
|
||||
46
examples/calvin/compose.yml
Normal file
46
examples/calvin/compose.yml
Normal file
@@ -0,0 +1,46 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/calvin/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: calvin
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/calvin/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
175
examples/calvin/main.py
Normal file
175
examples/calvin/main.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Runs a model in a CALVIN simulation environment."""
|
||||
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
|
||||
from calvin_agent.evaluation.multistep_sequences import get_sequences
|
||||
from calvin_agent.evaluation.utils import get_env_state_for_initial_condition
|
||||
import calvin_env
|
||||
from calvin_env.envs.play_table_env import get_env
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Command-line options of the CALVIN evaluation script."""

    # --- Model server parameters -------------------------------------------------------------------------------
    host: str = "0.0.0.0"
    port: int = 8000
    # Number of actions executed from each predicted chunk before the policy is queried again.
    replan_steps: int = 5

    # --- CALVIN environment-specific parameters ----------------------------------------------------------------
    calvin_data_path: str = "/datasets/calvin_debug_dataset"  # Path to CALVIN dataset for loading validation tasks
    max_subtask_steps: int = 360  # Max number of steps per subtask
    num_trials: int = 1000  # Number of rollouts per task

    # --- Utils -------------------------------------------------------------------------------------------------
    video_out_path: str = "data/calvin/videos"  # Path to save videos
    num_save_videos: int = 5  # Number of videos to be logged per task
    video_temp_subsample: int = 5  # Temporal subsampling to make videos shorter

    seed: int = 7  # Random Seed (for reproducibility)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
    """Evaluate a remote policy on CALVIN multi-step task sequences.

    For every evaluation sequence the environment is reset to the sequence's
    initial state and the subtasks are attempted in order. A subtask gets at
    most ``args.max_subtask_steps`` environment steps; the policy server is
    queried for a new action chunk whenever the previous chunk's first
    ``args.replan_steps`` actions have been executed. An episode stops at the
    first failed subtask. Videos of the first ``args.num_save_videos``
    episodes are written to ``args.video_out_path``, and running as well as
    final success statistics are logged.

    Args:
        args: Parsed command-line options, see ``Args``.
    """
    # Set random seed
    np.random.seed(args.seed)

    # Initialize CALVIN environment
    env = get_env(pathlib.Path(args.calvin_data_path) / "validation", show_gui=False)

    # Get CALVIN eval task set
    task_definitions, task_instructions, task_reward = _get_calvin_tasks_and_reward(args.num_trials)

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

    # Make sure the video folder exists before the first rollout is written.
    video_dir = pathlib.Path(args.video_out_path)
    if args.num_save_videos > 0:
        video_dir.mkdir(parents=True, exist_ok=True)

    # Start evaluation.
    episode_solved_subtasks = []
    per_subtask_success = collections.defaultdict(list)
    for i, (initial_state, task_sequence) in enumerate(tqdm.tqdm(task_definitions)):
        logging.info(f"Starting episode {i+1}...")
        logging.info(f"Task sequence: {task_sequence}")

        # Reset env to initial position for task
        robot_obs, scene_obs = get_env_state_for_initial_condition(initial_state)
        env.reset(robot_obs=robot_obs, scene_obs=scene_obs)

        rollout_images = []
        solved_subtasks = 0
        for subtask in task_sequence:
            start_info = env.get_info()
            action_plan = collections.deque()

            obs = env.get_obs()
            done = False
            for _ in range(args.max_subtask_steps):
                img = obs["rgb_obs"]["rgb_static"]
                wrist_img = obs["rgb_obs"]["rgb_gripper"]
                # Keep frames in their native channel-last layout: they are passed
                # straight to imageio.mimwrite below, which is given [H, W, C]
                # frames elsewhere in this project.
                rollout_images.append(img)

                if not action_plan:
                    # Finished executing previous action chunk -- compute new chunk
                    # Prepare observations dict
                    element = {
                        "observation/rgb_static": img,
                        "observation/rgb_gripper": wrist_img,
                        "observation/state": obs["robot_obs"],
                        "prompt": str(task_instructions[subtask][0]),
                    }

                    # Query model to get action
                    action_chunk = client.infer(element)["actions"]
                    assert (
                        len(action_chunk) >= args.replan_steps
                    ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                    action_plan.extend(action_chunk[: args.replan_steps])

                action = action_plan.popleft()

                # Round gripper action since env expects gripper_action in (-1, 1)
                action[-1] = 1 if action[-1] > 0 else -1

                # Step environment
                obs, _, _, current_info = env.step(action)

                # check if current step solves a task
                current_task_info = task_reward.get_task_info_for_set(start_info, current_info, {subtask})
                if len(current_task_info) > 0:
                    done = True
                    solved_subtasks += 1
                    break

            per_subtask_success[subtask].append(int(done))
            if not done:
                # Subtask execution failed --> stop episode
                break

        episode_solved_subtasks.append(solved_subtasks)
        # The current episode has already been appended, so "<=" is what yields
        # exactly num_save_videos videos (rollout_1 ... rollout_N).
        if len(episode_solved_subtasks) <= args.num_save_videos:
            # Save rollout video.
            idx = len(episode_solved_subtasks)
            imageio.mimwrite(
                video_dir / f"rollout_{idx}.mp4",
                [np.asarray(x) for x in rollout_images[:: args.video_temp_subsample]],
                fps=50 // args.video_temp_subsample,
            )

        # Print current performance after each episode
        logging.info(f"Solved subtasks: {solved_subtasks}")
        _calvin_print_performance(episode_solved_subtasks, per_subtask_success)

    # Log final performance
    logging.info(f"results/avg_num_subtasks: : {np.mean(episode_solved_subtasks)}")
    # A plain list cannot be compared with an int, so convert it to an array once.
    solved_counts = np.array(episode_solved_subtasks)
    for i in range(1, 6):
        # Compute fraction of episodes that have *at least* i successful subtasks
        logging.info(
            f"results/avg_success_len_{i}: {np.sum(solved_counts >= i) / len(episode_solved_subtasks)}"
        )
    for key in per_subtask_success:
        logging.info(f"results/avg_success__{key}: {np.mean(per_subtask_success[key])}")
|
||||
|
||||
|
||||
def _get_calvin_tasks_and_reward(num_sequences):
    """Load the CALVIN evaluation sequences, language annotations and task oracle.

    The configuration files are looked up in ``calvin_models/conf``, located
    relative to the installed ``calvin_env`` package.

    Args:
        num_sequences: Number of evaluation sequences requested from ``get_sequences``.

    Returns:
        Tuple ``(eval_sequences, val_annotations, task_oracle)``.
    """
    package_file = pathlib.Path(calvin_env.__file__).absolute()
    conf_dir = package_file.parents[2] / "calvin_models" / "conf"

    # Task oracle, instantiated from its hydra config.
    task_oracle = hydra.utils.instantiate(
        OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml")
    )
    val_annotations = OmegaConf.load(conf_dir / "annotations/new_playtable_validation.yaml")
    eval_sequences = get_sequences(num_sequences)
    return eval_sequences, val_annotations, task_oracle
|
||||
|
||||
|
||||
def _calvin_print_performance(episode_solved_subtasks, per_subtask_success):
|
||||
# Compute avg success rate per task length
|
||||
logging.info("#####################################################")
|
||||
logging.info(f"Avg solved subtasks: {np.mean(episode_solved_subtasks)}\n")
|
||||
|
||||
logging.info("Per sequence_length avg success:")
|
||||
for i in range(1, 6):
|
||||
# Compute fraction of episodes that have *at least* i successful subtasks
|
||||
logging.info(f"{i}: {np.sum(np.array(episode_solved_subtasks) >= i) / len(episode_solved_subtasks) * 100}%")
|
||||
|
||||
logging.info("\n Per subtask avg success:")
|
||||
for key in per_subtask_success:
|
||||
logging.info(f"{key}: \t\t\t {np.mean(per_subtask_success[key]) * 100}%")
|
||||
logging.info("#####################################################")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tyro.cli(main)
|
||||
59
examples/libero/Dockerfile
Normal file
59
examples/libero/Dockerfile
Normal file
@@ -0,0 +1,59 @@
|
||||
# syntax=docker/dockerfile:1
# Dockerfile for the LIBERO benchmark.
#
# The syntax directive above must stay on the first line: the here-document
# used further down to write the LIBERO config requires Dockerfile frontend >= 1.4.

# Build the container:
# docker build . -t libero -f examples/libero/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash

FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

# Install the OS packages (sorted for easier diffs). The frontend is forced to
# non-interactive for this command only so a package prompt cannot stall the build,
# and the apt package lists are removed in the same layer so they do not end up
# in the image.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y \
    clang \
    g++ \
    libgl1-mesa-glx \
    libgles2-mesa-dev \
    libglew-dev \
    libglfw3-dev \
    libglib2.0-0 \
    libosmesa6-dev \
    libsm6 \
    libxext6 \
    libxrender1 \
    make && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero

# Create a default config file to avoid an input prompt from LIBERO's init script.
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
ENV LIBERO_CONFIG_PATH=/tmp/libero
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
benchmark_root: /app/third_party/libero/libero/libero
bddl_files: /app/third_party/libero/libero/libero/bddl_files
init_states: /app/third_party/libero/libero/libero/init_files
datasets: /app/third_party/libero/libero/datasets
assets: /app/third_party/libero/libero/libero/assets
EOF

# Default command: activate the virtual environment and run the benchmark script
# from the mounted project directory.
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"]
|
||||
39
examples/libero/README.md
Normal file
39
examples/libero/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# LIBERO Benchmark
|
||||
|
||||
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
|
||||
|
||||
Note: When regenerating `requirements.txt` in this directory, add the flag `--extra-index-url https://download.pytorch.org/whl/cu113` to the `uv pip compile` command.
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
# Grant access to the X11 server:
|
||||
sudo xhost +local:docker
|
||||
|
||||
export SERVER_ARGS="--env LIBERO"
|
||||
docker compose -f examples/libero/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.8 examples/libero/.venv
|
||||
source examples/libero/.venv/bin/activate
|
||||
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
|
||||
uv pip install -e packages/openpi-client
|
||||
uv pip install -e third_party/libero
|
||||
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
|
||||
|
||||
# Run the simulation
|
||||
python examples/libero/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env LIBERO
|
||||
```
|
||||
49
examples/libero/compose.yml
Normal file
49
examples/libero/compose.yml
Normal file
@@ -0,0 +1,49 @@
|
||||
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
  # LIBERO benchmark client; started after the policy server container.
  runtime:
    build:
      context: ../..
      dockerfile: examples/libero/Dockerfile
    image: libero
    depends_on:
      - openpi_server
    init: true
    tty: true
    network_mode: host
    privileged: true
    environment:
      - DISPLAY=$DISPLAY
    volumes:
      - $PWD:/app
      - ../../data:/data
      - /tmp/.X11-unix:/tmp/.X11-unix:ro
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities:
                - gpu

  # Policy server; SERVER_ARGS is passed through from the invoking shell.
  openpi_server:
    build:
      context: ../..
      dockerfile: scripts/serve_policy.Dockerfile
    image: openpi_server
    init: true
    tty: true
    network_mode: host
    environment:
      - SERVER_ARGS
    volumes:
      - $PWD:/app

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities:
                - gpu
|
||||
215
examples/libero/main.py
Normal file
215
examples/libero/main.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import math
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
from libero.libero import benchmark
|
||||
from libero.libero import get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
|
||||
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Args:
    """Command-line options for the LIBERO evaluation script."""

    #################################################################################################################
    # Model server parameters
    #################################################################################################################

    # Host of the policy server the websocket client is created with
    host: str = "0.0.0.0"
    # Port of the policy server the websocket client is created with
    port: int = 8000
    # Side length (in pixels) both camera images are resized/padded to before being sent to the policy
    resize_size: int = 224
    # Number of actions executed from each predicted action chunk before querying the policy again
    replan_steps: int = 5

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = (
        "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    )
    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50  # Number of rollouts per task

    #################################################################################################################
    # Utils
    #################################################################################################################
    video_out_path: str = "data/libero/videos"  # Path to save videos

    seed: int = 7  # Random Seed (for reproducibility)
|
||||
|
||||
|
||||
def eval_libero(args: Args) -> None:
    """Roll out a remote policy on every task of a LIBERO task suite and log success rates.

    For each task the environment is created once, then ``args.num_trials_per_task``
    episodes are run from the suite's stored initial states. Observations are sent
    to the policy server through a websocket client; from every returned action
    chunk the first ``args.replan_steps`` actions are executed before the policy is
    queried again. A replay video is written per episode into ``args.video_out_path``
    (the file name contains only the task and the outcome, so later episodes of the
    same task and outcome overwrite earlier ones).

    Args:
        args: Parsed command-line options, see ``Args``.

    Raises:
        ValueError: If ``args.task_suite_name`` has no known step budget.
    """
    # Set random seed
    np.random.seed(args.seed)

    # Initialize LIBERO task suite
    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    num_tasks_in_suite = task_suite.n_tasks
    logging.info(f"Task suite: {args.task_suite_name}")

    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)

    # Step budget per suite, chosen slightly above the longest training demo.
    max_steps_by_suite = {
        "libero_spatial": 220,  # longest training demo has 193 steps
        "libero_object": 280,  # longest training demo has 254 steps
        "libero_goal": 300,  # longest training demo has 270 steps
        "libero_10": 520,  # longest training demo has 505 steps
        "libero_90": 400,  # longest training demo has 373 steps
    }
    if args.task_suite_name not in max_steps_by_suite:
        raise ValueError(f"Unknown task suite: {args.task_suite_name}")
    max_steps = max_steps_by_suite[args.task_suite_name]

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

    # Start evaluation
    total_episodes, total_successes = 0, 0
    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
        # Get task
        task = task_suite.get_task(task_id)

        # Get default LIBERO initial states
        initial_states = task_suite.get_task_init_states(task_id)

        # Initialize LIBERO environment and task description
        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

        # Start episodes
        task_episodes, task_successes = 0, 0
        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
            logging.info(f"\nTask: {task_description}")

            # Reset environment
            env.reset()
            action_plan = collections.deque()

            # Set initial states
            obs = env.set_init_state(initial_states[episode_idx])

            # Setup
            t = 0
            replay_images = []
            # Outcome of THIS episode. It is reset here so that an exception raised
            # before the environment reports completion can neither leave it
            # undefined nor reuse the previous episode's result.
            success = False

            logging.info(f"Starting episode {task_episodes+1}...")
            while t < max_steps + args.num_steps_wait:
                try:
                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
                    # and we need to wait for them to fall
                    if t < args.num_steps_wait:
                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
                        t += 1
                        continue

                    # Get preprocessed image
                    # IMPORTANT: rotate 180 degrees to match train preprocessing
                    img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
                    wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
                    img = image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
                    wrist_img = image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)

                    # Save preprocessed image for replay video
                    replay_images.append(img)

                    if not action_plan:
                        # Finished executing previous action chunk -- compute new chunk
                        # Prepare observations dict
                        element = {
                            "observation/image": img,
                            "observation/wrist_image": wrist_img,
                            "observation/state": np.concatenate(
                                (
                                    obs["robot0_eef_pos"],
                                    _quat2axisangle(obs["robot0_eef_quat"]),
                                    obs["robot0_gripper_qpos"],
                                )
                            ),
                            "prompt": str(task_description),
                        }

                        # Query model to get action
                        action_chunk = client.infer(element)["actions"]
                        assert (
                            len(action_chunk) >= args.replan_steps
                        ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                        action_plan.extend(action_chunk[: args.replan_steps])

                    action = action_plan.popleft()

                    # Execute action in environment
                    obs, reward, done, info = env.step(action.tolist())
                    if done:
                        success = True
                        task_successes += 1
                        total_successes += 1
                        break
                    t += 1

                except Exception as e:
                    # Best effort: a failing episode is counted as a failure and the
                    # evaluation continues; the traceback is kept in the log.
                    logging.exception(f"Caught exception: {e}")
                    break

            task_episodes += 1
            total_episodes += 1

            # Save a replay video of the episode (skipped when no frame was recorded,
            # e.g. when the episode failed during the initial wait steps)
            suffix = "success" if success else "failure"
            task_segment = task_description.replace(" ", "_")
            if replay_images:
                imageio.mimwrite(
                    pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
                    [np.asarray(x) for x in replay_images],
                    fps=10,
                )

            # Log current results
            logging.info(f"Success: {success}")
            logging.info(f"# episodes completed so far: {total_episodes}")
            logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")

        # Log final results
        logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
        logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")

    logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
    logging.info(f"Total episodes: {total_episodes}")
|
||||
|
||||
|
||||
def _get_libero_env(task, resolution, seed):
    """Build the off-screen LIBERO environment for ``task``.

    Args:
        task: Task object providing ``language``, ``problem_folder`` and ``bddl_file``.
        resolution: Used for both the camera height and the camera width.
        seed: Passed to ``env.seed``.

    Returns:
        A tuple ``(env, task_description)`` where ``task_description`` is ``task.language``.
    """
    instruction = task.language

    bddl_root = pathlib.Path(get_libero_path("bddl_files"))
    bddl_path = bddl_root / task.problem_folder / task.bddl_file

    environment = OffScreenRenderEnv(
        bddl_file_name=bddl_path,
        camera_heights=resolution,
        camera_widths=resolution,
    )
    # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    environment.seed(seed)

    return environment, instruction
|
||||
|
||||
|
||||
def _quat2axisangle(quat):
|
||||
"""
|
||||
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||
"""
|
||||
# clip quaternion
|
||||
if quat[3] > 1.0:
|
||||
quat[3] = 1.0
|
||||
elif quat[3] < -1.0:
|
||||
quat[3] = -1.0
|
||||
|
||||
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||
if math.isclose(den, 0.0):
|
||||
# This is (close to) a zero degree rotation, immediately return
|
||||
return np.zeros(3)
|
||||
|
||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging so the progress and result
    # messages are shown, then hand `eval_libero` to tyro.cli.
    logging.basicConfig(level=logging.INFO)
    tyro.cli(eval_libero)
|
||||
11
examples/libero/requirements.in
Normal file
11
examples/libero/requirements.in
Normal file
@@ -0,0 +1,11 @@
|
||||
imageio[ffmpeg]
|
||||
numpy==1.22.4
|
||||
tqdm
|
||||
tyro
|
||||
PyYaml
|
||||
opencv-python==4.6.0.66
|
||||
torch==1.11.0+cu113
|
||||
torchvision==0.12.0+cu113
|
||||
torchaudio==0.11.0+cu113
|
||||
robosuite==1.4.1
|
||||
matplotlib==3.5.3
|
||||
136
examples/libero/requirements.txt
Normal file
136
examples/libero/requirements.txt
Normal file
@@ -0,0 +1,136 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
|
||||
absl-py==2.1.0
|
||||
# via mujoco
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
eval-type-backport==0.2.0
|
||||
# via tyro
|
||||
evdev==1.7.1
|
||||
# via pynput
|
||||
fonttools==4.55.3
|
||||
# via matplotlib
|
||||
glfw==1.12.0
|
||||
# via mujoco
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.35.1
|
||||
# via -r examples/libero/requirements.in
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
importlib-metadata==8.5.0
|
||||
# via typeguard
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
llvmlite==0.36.0
|
||||
# via numba
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.5.3
|
||||
# via -r examples/libero/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mujoco==3.2.3
|
||||
# via robosuite
|
||||
numba==0.53.1
|
||||
# via robosuite
|
||||
numpy==1.22.4
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# imageio
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# robosuite
|
||||
# scipy
|
||||
# torchvision
|
||||
opencv-python==4.6.0.66
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# robosuite
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
# robosuite
|
||||
# torchvision
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pynput==1.7.7
|
||||
# via robosuite
|
||||
pyopengl==3.1.7
|
||||
# via mujoco
|
||||
pyparsing==3.1.4
|
||||
# via matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pyyaml==6.0.2
|
||||
# via -r examples/libero/requirements.in
|
||||
requests==2.32.3
|
||||
# via torchvision
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
robosuite==1.4.1
|
||||
# via -r examples/libero/requirements.in
|
||||
scipy==1.10.1
|
||||
# via robosuite
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# imageio-ffmpeg
|
||||
# numba
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
termcolor==2.4.0
|
||||
# via robosuite
|
||||
torch==1.11.0+cu113
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# torchaudio
|
||||
# torchvision
|
||||
torchaudio==0.11.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
torchvision==0.12.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
tqdm==4.67.1
|
||||
# via -r examples/libero/requirements.in
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# torch
|
||||
# torchvision
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/libero/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
zipp==3.20.2
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
# importlib-resources
|
||||
134
examples/policy_records.ipynb
Normal file
134
examples/policy_records.ipynb
Normal file
@@ -0,0 +1,134 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pathlib\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"record_path = pathlib.Path(\"../policy_records\")\n",
|
||||
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
|
||||
"\n",
|
||||
"records = []\n",
|
||||
"for i in range(num_steps):\n",
|
||||
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
|
||||
" records.append(record)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"length of records\", len(records))\n",
|
||||
"print(\"keys in records\", records[0].keys())\n",
|
||||
"\n",
|
||||
"for k in records[0]:\n",
|
||||
" print(f\"{k} shape: {records[0][k].shape}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_image(step: int, idx: int = 0):\n",
|
||||
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
|
||||
" return img[idx].transpose(1, 2, 0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def show_image(step: int, idx_lst: list[int]):\n",
|
||||
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
|
||||
" return Image.fromarray(np.hstack(imgs))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for i in range(2):\n",
|
||||
" display(show_image(i, [0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_axis(name, axis):\n",
|
||||
" return np.array([record[name][axis] for record in records])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# qpos is [..., 14] of type float:\n",
|
||||
"# 0-5: left arm joint angles\n",
|
||||
"# 6: left arm gripper\n",
|
||||
"# 7-12: right arm joint angles\n",
|
||||
"# 13: right arm gripper\n",
|
||||
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_data():\n",
|
||||
" cur_dim = 0\n",
|
||||
" in_data = {}\n",
|
||||
" out_data = {}\n",
|
||||
" for name, dim_size in names:\n",
|
||||
" for i in range(dim_size):\n",
|
||||
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
|
||||
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
|
||||
" cur_dim += 1\n",
|
||||
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"in_data, out_data = make_data()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for name in in_data.columns:\n",
|
||||
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
|
||||
" data.plot()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
32
examples/simple_client/Dockerfile
Normal file
32
examples/simple_client/Dockerfile
Normal file
@@ -0,0 +1,32 @@
|
||||
# Dockerfile for the simple client.
#
# Build:
#   docker build . -t simple_client -f examples/simple_client/Dockerfile
#
# Run:
#   docker run --rm -it --network=host -v .:/app simple_client /bin/bash

FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# UV_LINK_MODE=copy: copy from the cache instead of linking since it's a mounted volume.
# UV_PROJECT_ENVIRONMENT: keep the virtual environment outside of the project directory
# so it doesn't leak out of the container when the application code is mounted.
ENV UV_LINK_MODE=copy \
    UV_PROJECT_ENVIRONMENT=/.venv

# Only the dependency manifests are copied into the image; the rest of the project is
# mounted as a volume at run time, so source changes do not require a rebuild.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Create the virtual environment and install the python dependencies into it.
RUN uv venv --python 3.7 ${UV_PROJECT_ENVIRONMENT}
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

# Default command: activate the virtual environment and run the client script
# from the mounted project directory.
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/simple_client/main.py"]
|
||||
24
examples/simple_client/README.md
Normal file
24
examples/simple_client/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Simple Client
|
||||
|
||||
A minimal client that sends observations to the server and prints the inference rate.
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--example aloha"
|
||||
docker compose -f examples/simple_client/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py
|
||||
```
|
||||
37
examples/simple_client/compose.yml
Normal file
37
examples/simple_client/compose.yml
Normal file
@@ -0,0 +1,37 @@
|
||||
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
  # Simple client; started after the policy server container.
  runtime:
    build:
      context: ../..
      dockerfile: examples/simple_client/Dockerfile
    image: simple_client
    depends_on:
      - openpi_server
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app

  # Policy server; SERVER_ARGS is passed through from the invoking shell.
  openpi_server:
    build:
      context: ../..
      dockerfile: scripts/serve_policy.Dockerfile
    image: openpi_server
    init: true
    tty: true
    network_mode: host
    environment:
      - SERVER_ARGS
    volumes:
      - $PWD:/app

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities:
                - gpu
|
||||
81
examples/simple_client/main.py
Normal file
81
examples/simple_client/main.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Args:
    """Command-line options for the simple client."""

    # Host and port the websocket policy client is created with.
    host: str = "0.0.0.0"
    port: int = 8000

    # Which random-observation generator to use: aloha, droid, calvin or libero.
    example: str = "droid"
|
||||
|
||||
|
||||
def main(args: Args) -> None:
    """Send random observations to the policy server and print the inference rate.

    One warm-up request is sent first, then a fixed number of timed requests.
    The interval is measured with ``time.perf_counter`` so that adjustments of
    the system clock cannot distort the reported rate.

    Args:
        args: Parsed command-line options, see ``Args``.

    Raises:
        KeyError: If ``args.example`` is not one of aloha, droid, calvin, libero.
    """
    obs_fn = {
        "aloha": _random_observation_aloha,
        "droid": _random_observation_droid,
        "calvin": _random_observation_calvin,
        "libero": _random_observation_libero,
    }[args.example]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )

    # Send 1 observation to make sure the model is loaded.
    policy.infer(obs_fn())

    # Number of timed requests; used for both the loop and the rate computation.
    num_requests = 100

    start = time.perf_counter()
    for _ in range(num_requests):
        policy.infer(obs_fn())
    elapsed = time.perf_counter() - start

    print(f"Total time taken: {elapsed}")
    # Note that each inference returns many action chunks.
    print(f"Inference rate: {num_requests / elapsed} Hz")
|
||||
|
||||
|
||||
def _random_observation_aloha() -> dict:
|
||||
return {
|
||||
"qpos": np.ones((14,)),
|
||||
"image": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_droid() -> dict:
|
||||
return {
|
||||
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/joint_position": np.random.rand(7),
|
||||
"observation/gripper_position": np.random.rand(1),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_calvin() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(15),
|
||||
"observation/rgb_static": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"observation/rgb_gripper": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_libero() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(8),
|
||||
"observation/image": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"observation/wrist_image": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging, then hand `main` to tyro.cli.
    logging.basicConfig(level=logging.INFO)
    tyro.cli(main)
|
||||
2
examples/simple_client/requirements.in
Normal file
2
examples/simple_client/requirements.in
Normal file
@@ -0,0 +1,2 @@
|
||||
numpy
|
||||
tyro
|
||||
27
examples/simple_client/requirements.txt
Normal file
27
examples/simple_client/requirements.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
|
||||
backports-cached-property==1.0.2
|
||||
# via tyro
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
eval-type-backport==0.1.3
|
||||
# via tyro
|
||||
markdown-it-py==2.2.0
|
||||
# via rich
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
numpy==1.21.6
|
||||
# via -r examples/simple_client/requirements.in
|
||||
pygments==2.17.2
|
||||
# via rich
|
||||
rich==13.8.1
|
||||
# via tyro
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
typing-extensions==4.7.1
|
||||
# via
|
||||
# markdown-it-py
|
||||
# rich
|
||||
# tyro
|
||||
tyro==0.9.1
|
||||
# via -r examples/simple_client/requirements.in
|
||||
Reference in New Issue
Block a user