Compare commits

47 Commits (02-05-defa...main)
| Author | SHA1 | Date |
|---|---|---|
| | 3827c0e255 | |
| | 55fed92ccc | |
| | 65d864861b | |
| | 36dc3c037e | |
| | bb7a3b4a3e | |
| | f37e6dd7fb | |
| | eb28153241 | |
| | 16affa3bee | |
| | 581e07d73a | |
| | 6c514a6d8a | |
| | 92b1082442 | |
| | f1b9f4ab71 | |
| | a4b1bf92f1 | |
| | 31289dbd72 | |
| | cd0e9a2e0e | |
| | 620a56a399 | |
| | 42e4838aca | |
| | 3409be890e | |
| | d139c700e4 | |
| | d0b6231bd3 | |
| | 4a10482dfb | |
| | bf25a4d9c4 | |
| | 7dccd73b37 | |
| | 29068dd274 | |
| | ba68b3d97b | |
| | cd82848a99 | |
| | 8d288e4b41 | |
| | 90b87cc42c | |
| | 0a67d46b0d | |
| | 80d346ea0d | |
| | 16788f847e | |
| | f00207b91c | |
| | b7c8bf24d4 | |
| | ed05e55074 | |
| | 007e2b91ed | |
| | fa5cf91df1 | |
| | 2a13ed7eff | |
| | 9c1376bcc1 | |
| | 9675e12c4e | |
| | bf30fa3d4c | |
| | f543cb1d87 | |
| | 6104624aca | |
| | 06cdf3a27f | |
| | f8ce5c9479 | |
| | 153e34cefe | |
| | f61cd24a15 | |
| | ed11c29742 | |

.dockerignore (0 changes; Normal file → Executable file)

.github/CODEOWNERS (vendored; 0 changes; Normal file → Executable file)

.github/workflows/pre-commit.yml (vendored; 0 changes; Normal file → Executable file)

.github/workflows/test.yml (vendored; 2 changes; Normal file → Executable file)

@@ -7,7 +7,7 @@ on:
jobs:
  run_tests:
    name: Run Tests
-    runs-on: verylarge
+    runs-on: openpi-verylarge
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:

.gitignore (vendored; 2 changes; Normal file → Executable file)

@@ -12,6 +12,8 @@ __pycache__/
# C extensions
*.so
+
+third-party/*

# Distribution / packaging
.Python
build/

.gitmodules (vendored; 4 changes; Normal file → Executable file)

@@ -1,6 +1,6 @@
[submodule "third_party/aloha"]
	path = third_party/aloha
-	url = git@github.com:Physical-Intelligence/aloha.git
+	url = https://github.com/Physical-Intelligence/aloha.git
[submodule "third_party/libero"]
	path = third_party/libero
-	url = git@github.com:Lifelong-Robot-Learning/LIBERO.git
+	url = https://github.com/Lifelong-Robot-Learning/LIBERO.git

.pre-commit-config.yaml (0 changes; Normal file → Executable file)

.python-version (0 changes; Normal file → Executable file)

.vscode/settings.json (vendored; 0 changes; Normal file → Executable file)

CONTRIBUTING.md (0 changes; Normal file → Executable file)

README.md (9 changes; Normal file → Executable file)

@@ -38,6 +38,7 @@ We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [

```bash
GIT_LFS_SKIP_SMUDGE=1 uv sync
+GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```

NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.

@@ -66,7 +67,7 @@ We also provide "expert" checkpoints for various robot platforms and tasks. Thes
| $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `s3://openpi-assets/checkpoints/pi0_droid` |
| $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on [ALOHA](https://tonyzhaozh.github.io/aloha/) robot platforms | `s3://openpi-assets/checkpoints/pi0_aloha_towel` |
| $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can unpack food from a tupperware container | `s3://openpi-assets/checkpoints/pi0_aloha_tupperware` |
-| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](XXX), can uncap a pen | `s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
+| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](https://dit-policy.github.io/), can uncap a pen | `s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |

By default, checkpoints are automatically downloaded from `s3://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can overwrite the download path by setting the `OPENPI_DATA_HOME` environment variable.
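
As an aside on the `OPENPI_DATA_HOME` variable just mentioned: a minimal sketch of redirecting the cache from Python, assuming you set it before any openpi code runs (the scratch path is a hypothetical example):

```python
import os

# Point the openpi checkpoint cache at a scratch disk instead of
# ~/.cache/openpi. Set this before importing or launching any openpi
# code, since the cache location is read when downloads are triggered.
os.environ["OPENPI_DATA_HOME"] = "/mnt/scratch/openpi"
```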

@@ -148,6 +149,8 @@ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp

The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. To make maximal use of GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%).

+**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.
+
### 3. Spinning up a policy server and running inference

Once training is complete, we can run inference by spinning up a policy server and then querying it from a Libero evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed):

@@ -158,12 +161,16 @@ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero

This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run the Libero evaluation script to query the server. For instructions on how to install Libero and run the evaluation script, see the [Libero README](examples/libero/README.md).

+If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
+
### More Examples

We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
- [ALOHA Simulator](examples/aloha_sim)
- [ALOHA Real](examples/aloha_real)
- [UR5](examples/ur5)
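
Returning to the `XLA_PYTHON_CLIENT_MEM_FRACTION` note above: if your entry point is a Python script rather than a shell command, the same setting can be applied in code. A minimal sketch; the ordering constraint is standard JAX behavior (the variable must be set before JAX initializes its GPU backend):

```python
import os

# Let JAX pre-allocate up to 90% of GPU memory (default is 75%).
# Must be set before the first `import jax`, since XLA reads it when
# the GPU backend is initialized.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import jax  # noqa: E402  (imported after setting the env var on purpose)
```
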
docs/docker.md (20 changes; Normal file → Executable file)

@@ -2,6 +2,24 @@
All of the examples in this repo provide instructions for being run normally, and also using Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for the examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.

-Docker installation instructions are [here](https://docs.docker.com/engine/install/). If using a GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). If your host machine is Ubuntu 22.04, you can use the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
+- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
+- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
+- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
+- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
+
+If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.

Build the Docker image and start the container with the following command:
```bash
docker compose -f scripts/docker/compose.yml up --build
```

To build and run the Docker image for a specific example, use the following command:
```bash
docker compose -f examples/<example_name>/compose.yml up --build
```
where `<example_name>` is the name of the example you want to run.

During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.

docs/norm_stats.md (69 changes; Executable file)

@@ -0,0 +1,69 @@
# Normalization statistics

Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.

## Reloading normalization statistics

When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.

**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, as shown below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:

```python
TrainConfig(
    ...
    data=LeRobotAlohaDataConfig(
        ...
        assets=AssetsConfig(
            assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
    ),
)
```

For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).

**Note:** To successfully reload normalization statistics, it's important that your robot and dataset follow the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.

**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend trying both options -- reloading the pre-training statistics and training with a fresh set of statistics computed on your new dataset (see the [main README](../README.md) for instructions on how to compute new statistics) -- and picking the one that works better for your task.

## Provided Pre-training Normalization Statistics

Below is a list of all the pre-training normalization statistics we provide, for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `s3://openpi-assets/checkpoints/pi0_base/assets`; for `pi0_fast_base`, set the `assets_dir` to `s3://openpi-assets/checkpoints/pi0_fast_base/assets`.

| Robot | Description | Asset ID |
|-------|-------------|----------|
| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |

## Pi0 Model Action Space Definitions

Out of the box, both the `pi0_base` and `pi0_fast_base` models use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
```
"dim_0:dim_5": "left arm joint angles",
"dim_6": "left arm gripper position",
"dim_7:dim_12": "right arm joint angles (for bi-manual only)",
"dim_13": "right arm gripper position (for bi-manual only)",

# For mobile robots:
"dim_14:dim_15": "x-y base velocity (for mobile robots only)",
```

The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.

For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.

General info for Pi robots:
- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
- Control frequencies are 20 Hz for UR5e and Franka arms, and 50 Hz for ARX and Trossen (ALOHA) arms.

For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension, at a control frequency of 15 Hz.
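
To make the layout concrete, here is a minimal sketch of packing a single-arm, non-mobile robot's command into these dimensions. The helper is illustrative and not part of the openpi API; the 16-dim buffer simply covers dims 0-15 from the list above:

```python
import numpy as np


def pack_single_arm_action(joint_angles: np.ndarray, gripper: float) -> np.ndarray:
    """Illustrative only: place a 6-DoF arm command into the layout above."""
    assert joint_angles.shape == (6,)  # joint angles in radians, per the conventions above
    action = np.zeros(16, dtype=np.float32)  # dims 0..15
    action[0:6] = joint_angles  # "dim_0:dim_5": left arm joint angles
    action[6] = gripper         # "dim_6": gripper in [0.0, 1.0] (0 = fully open, 1 = fully closed)
    # dims 7..13 (right arm) stay zero for single-arm robots;
    # dims 14..15 (x-y base velocity) stay zero for non-mobile robots.
    return action
```
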
docs/remote_inference.md (35 changes; Normal file → Executable file)

@@ -33,10 +33,39 @@ pip install -e .
Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:

```python
from openpi_client import image_tools
from openpi_client import websocket_client_policy

-policy_client = websocket_client_policy.WebsocketClientPolicy(host="10.32.255.0", port=8000)
-action_chunk = policy_client.infer(example)["actions"]
+# Outside of episode loop, initialize the policy client.
+# Point to the host and port of the policy server (localhost and 8000 are the defaults).
+client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
+
+for step in range(num_steps):
+    # Inside the episode loop, construct the observation.
+    # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
+    # We provide utilities for resizing images + uint8 conversion so you match the training routines.
+    # The typical resize_size for pre-trained pi0 models is 224.
+    # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
+    observation = {
+        "observation/image": image_tools.convert_to_uint8(
+            image_tools.resize_with_pad(img, 224, 224)
+        ),
+        "observation/wrist_image": image_tools.convert_to_uint8(
+            image_tools.resize_with_pad(wrist_img, 224, 224)
+        ),
+        "observation/state": state,
+        "prompt": task_instruction,
+    }
+
+    # Call the policy server with the current observation.
+    # This returns an action chunk of shape (action_horizon, action_dim).
+    # Note that you typically only need to call the policy every N steps and execute steps
+    # from the predicted action chunk open-loop in the remaining steps.
+    action_chunk = client.infer(observation)["actions"]
+
+    # Execute the actions in the environment.
+    ...
```

-Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `example` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
+Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
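
As a sketch of the command-line route, assuming `tyro` for argument parsing (as this repo's own scripts use); only `WebsocketClientPolicy(host=..., port=...)` comes from the example above, the rest is illustrative:

```python
import dataclasses

import tyro
from openpi_client import websocket_client_policy


@dataclasses.dataclass
class Args:
    host: str = "localhost"  # override on the command line, e.g. `--host 10.32.255.0`
    port: int = 8000


def main(args: Args) -> None:
    client = websocket_client_policy.WebsocketClientPolicy(host=args.host, port=args.port)
    # ... run your episode loop as in the example above ...


if __name__ == "__main__":
    main(tyro.cli(Args))
```
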
examples/aloha_real/Dockerfile (0 changes; Normal file → Executable file)

examples/aloha_real/README.md (8 changes; Normal file → Executable file)

@@ -1,6 +1,6 @@
# Run Aloha (Real Robot)

-This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../../openpi/docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
+This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.

## Prerequisites

@@ -28,13 +28,13 @@ uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client

# Run the robot
-python examples/aloha_real/main.py
+python -m examples.aloha_real.main
```

Terminal window 2:

```bash
-roslaunch --wait aloha ros_nodes.launch
+roslaunch aloha ros_nodes.launch
```

Terminal window 3:

@@ -123,4 +123,4 @@ This task involves opening a tupperware filled with food and pouring the content

We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.

-IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
+IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.

examples/aloha_real/compose.yml (0 changes; Normal file → Executable file)

examples/aloha_real/constants.py (0 changes; Normal file → Executable file)

examples/aloha_real/convert_aloha_data_to_lerobot.py (2 changes; Normal file → Executable file)

@@ -155,7 +155,7 @@ def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
            # load one compressed image after the other in RAM and uncompress
            imgs_array = []
            for data in ep[f"/observations/images/{camera}"]:
-                imgs_array.append(cv2.imdecode(data, 1))
+                imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
            imgs_array = np.array(imgs_array)

        imgs_per_cam[camera] = imgs_array

examples/aloha_real/env.py (0 changes; Normal file → Executable file)

examples/aloha_real/main.py (0 changes; Normal file → Executable file)

examples/aloha_real/real_env.py (6 changes; Normal file → Executable file)

@@ -49,7 +49,11 @@ class RealEnv:
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
-            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
+            robot_model="vx300s",
+            group_name="arm",
+            gripper_name="gripper",
+            robot_name="puppet_right",
+            init_node=False
        )
        if setup_robots:
            self.setup_robots()

examples/aloha_real/requirements.in (0 changes; Normal file → Executable file)

examples/aloha_real/requirements.txt (0 changes; Normal file → Executable file)

examples/aloha_real/robot_utils.py (0 changes; Normal file → Executable file)

examples/aloha_real/video_display.py (0 changes; Normal file → Executable file)

examples/aloha_real_lyt/Dockerfile (70 changes; Executable file)

@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.

# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash

FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
SHELL ["/bin/bash", "-c"]

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    cmake \
    curl \
    libffi-dev \
    python3-rosdep \
    python3-rosinstall \
    python3-rosinstall-generator \
    whiptail \
    git \
    wget \
    openssh-client \
    ros-noetic-cv-bridge \
    ros-noetic-usb-cam \
    ros-noetic-realsense2-camera \
    keyboard-configuration

WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n

COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make

# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
    cd /python && \
    wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
    tar -zxvf Python-3.10.14.tgz && \
    cd Python-3.10.14 && \
    ls -lhR && \
    ./configure --enable-optimizations && \
    make install && \
    echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    cd ~ && rm -rf /python && \
    rm -rf /var/lib/apt/lists/*

COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml

ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app

# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh

ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]

examples/aloha_real_lyt/README.md (126 changes; Executable file)

@@ -0,0 +1,126 @@
# Run Aloha (Real Robot)

This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.

## Prerequisites

This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.

1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.

## With Docker

```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client

# Run the robot
python -m examples.aloha_real.main
```

Terminal window 2:

```bash
roslaunch aloha ros_nodes.launch
```

Terminal window 3:

```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```

## **ALOHA Checkpoint Guide**

The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.

While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen them work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.

---

### **Toast Task**

This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.

- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
  - Works on both real toast and rubber fake toast
  - Compatible with standard 2-slice toasters
  - Works with plates of varying colors

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />

- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower-center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).

### **Towel Task**

This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.

- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
  - Works on towels of varying solid colors
  - Performance is worse on heavily textured or striped towels

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />

- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.

### **Tupperware Task**

This task involves opening a tupperware filled with food and pouring the contents onto a plate.

- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
  - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
  - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
  - The policy has seen plates of varying solid colors.

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />

- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
  - Tupperware should be on the left.
  - Plate should be on the right or bottom.
- The tupperware flap should point toward the plate.

## Training on your own Aloha dataset

1. Convert the dataset to the LeRobot dataset v2.0 format.

   We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).

2. Define a training config that uses the custom dataset.

   We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.

   IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.

examples/aloha_real_lyt/compose.yml (66 changes; Executable file)

@@ -0,0 +1,66 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
  runtime:
    image: aloha_real
    depends_on:
      - aloha_ros_nodes
      - ros_master
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  aloha_ros_nodes:
    image: aloha_real
    depends_on:
      - ros_master
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - /dev:/dev
    command: roslaunch --wait aloha ros_nodes.launch

  ros_master:
    image: ros:noetic-robot
    network_mode: host
    privileged: true
    command:
      - roscore

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

examples/aloha_real_lyt/constants.py (71 changes; Executable file)

@@ -0,0 +1,71 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa

### Task parameters

### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]

# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844

# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213

############################ Helper functions ############################

MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
    MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
    PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))

MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
    MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
    PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))

MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)

MASTER_POS2JOINT = (
    lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    + MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
    lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    + PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)

MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
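
As a quick sanity check on the helpers above, a small example of the master-to-puppet joint mapping; the expected values follow directly from the constants in this file:

```python
from examples.aloha_real_lyt import constants

# Normalizing against the master limits and unnormalizing against the puppet
# limits maps the master's fully-open joint value onto the puppet's.
y = constants.MASTER2PUPPET_JOINT_FN(constants.MASTER_GRIPPER_JOINT_OPEN)  # 0.3083 -> 1.4910
assert abs(y - constants.PUPPET_GRIPPER_JOINT_OPEN) < 1e-9
```
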
examples/aloha_real_lyt/convert_aloha_data_to_lerobot.py (272 changes; Executable file)

@@ -0,0 +1,272 @@
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.

Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""

import dataclasses
from pathlib import Path
import shutil
from typing import Literal

import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro


@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    use_videos: bool = True
    tolerance_s: float = 0.0001
    image_writer_processes: int = 10
    image_writer_threads: int = 5
    video_backend: str | None = None


DEFAULT_DATASET_CONFIG = DatasetConfig()


def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
    motors = [
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
    ]
    cameras = [
        "cam_high",
        "cam_low",
        "cam_left_wrist",
        "cam_right_wrist",
    ]

    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
        "action": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
    }

    if has_velocity:
        features["observation.velocity"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    if has_effort:
        features["observation.effort"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    if Path(LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=50,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )


def get_cameras(hdf5_files: list[Path]) -> list[str]:
    with h5py.File(hdf5_files[0], "r") as ep:
        # ignore depth channel, not currently handled
        return [key for key in ep["/observations/images"].keys() if "depth" not in key]  # noqa: SIM118


def has_velocity(hdf5_files: list[Path]) -> bool:
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/qvel" in ep


def has_effort(hdf5_files: list[Path]) -> bool:
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/effort" in ep


def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    imgs_per_cam = {}
    for camera in cameras:
        uncompressed = ep[f"/observations/images/{camera}"].ndim == 4

        if uncompressed:
            # load all images in RAM
            imgs_array = ep[f"/observations/images/{camera}"][:]
        else:
            import cv2

            # load one compressed image after the other in RAM and uncompress
            imgs_array = []
            for data in ep[f"/observations/images/{camera}"]:
                imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
            imgs_array = np.array(imgs_array)

        imgs_per_cam[camera] = imgs_array
    return imgs_per_cam


def load_raw_episode_data(
    ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        velocity = None
        if "/observations/qvel" in ep:
            velocity = torch.from_numpy(ep["/observations/qvel"][:])

        effort = None
        if "/observations/effort" in ep:
            effort = torch.from_numpy(ep["/observations/effort"][:])

        imgs_per_cam = load_raw_images_per_camera(
            ep,
            [
                "cam_high",
                "cam_low",
                "cam_left_wrist",
                "cam_right_wrist",
            ],
        )

    return imgs_per_cam, state, action, velocity, effort


def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    if episodes is None:
        episodes = range(len(hdf5_files))

    for ep_idx in tqdm.tqdm(episodes):
        ep_path = hdf5_files[ep_idx]

        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
        num_frames = state.shape[0]

        for i in range(num_frames):
            frame = {
                "observation.state": state[i],
                "action": action[i],
            }

            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]

            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]

            dataset.add_frame(frame)

        dataset.save_episode(task=task)

    return dataset


def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = True,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        download_raw(raw_dir, repo_id=raw_repo_id)

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))

    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    tyro.cli(port_aloha)

examples/aloha_real_lyt/env.py (57 changes; Executable file)

@@ -0,0 +1,57 @@
from typing import List, Optional  # noqa: UP035

import einops
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override

from examples.aloha_real import real_env as _real_env


class AlohaRealEnvironment(_environment.Environment):
    """An environment for an Aloha robot on real hardware."""

    def __init__(
        self,
        reset_position: Optional[List[float]] = None,  # noqa: UP006,UP007
        render_height: int = 224,
        render_width: int = 224,
    ) -> None:
        self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
        self._render_height = render_height
        self._render_width = render_width

        self._ts = None

    @override
    def reset(self) -> None:
        self._ts = self._env.reset()

    @override
    def is_episode_complete(self) -> bool:
        return False

    @override
    def get_observation(self) -> dict:
        if self._ts is None:
            raise RuntimeError("Timestep is not set. Call reset() first.")

        obs = self._ts.observation
        for k in list(obs["images"].keys()):
            if "_depth" in k:
                del obs["images"][k]

        for cam_name in obs["images"]:
            img = image_tools.convert_to_uint8(
                image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
            )
            obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")

        return {
            "state": obs["qpos"],
            "images": obs["images"],
        }

    @override
    def apply_action(self, action: dict) -> None:
        self._ts = self._env.step(action["actions"])

examples/aloha_real_lyt/main.py (51 changes; Executable file)

@@ -0,0 +1,51 @@
import dataclasses
import logging

from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro

from examples.aloha_real_lyt import env as _env


@dataclasses.dataclass
class Args:
    host: str = "172.20.103.171"
    port: int = 8090

    action_horizon: int = 25

    num_episodes: int = 1
    max_episode_steps: int = 1000


def main(args: Args) -> None:
    ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")

    metadata = ws_client_policy.get_server_metadata()
    runtime = _runtime.Runtime(
        environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=ws_client_policy,
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[],
        max_hz=50,
        num_episodes=args.num_episodes,
        max_episode_steps=args.max_episode_steps,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)

171
examples/aloha_real_lyt/real_env.py
Executable file
171
examples/aloha_real_lyt/real_env.py
Executable file
@@ -0,0 +1,171 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
import collections
|
||||
import time
|
||||
from typing import Optional, List
|
||||
import dm_env
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
|
||||
from examples.aloha_real import constants
|
||||
from examples.aloha_real import robot_utils
|
||||
|
||||
# This is the reset position that is used by the standard Aloha runtime.
|
||||
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
|
||||
|
||||
|
||||
class RealEnv:
|
||||
"""
|
||||
Environment for real robot bi-manual manipulation
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
    def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
        # reset_position = START_ARM_POSE[:6]
        self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION

        self.puppet_bot_left = InterbotixManipulatorXS(
            robot_model="vx300s",
            group_name="arm",
            gripper_name="gripper",
            robot_name="puppet_left",
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
        )
        if setup_robots:
            self.setup_robots()

        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        robot_utils.setup_puppet_bot(self.puppet_bot_left)
        robot_utils.setup_puppet_bot(self.puppet_bot_right)

    def get_qpos(self):
        left_qpos_raw = self.recorder_left.qpos
        right_qpos_raw = self.recorder_right.qpos
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
        ]  # this is position not joint
        right_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
        ]  # this is position not joint
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    def get_qvel(self):
        left_qvel_raw = self.recorder_left.qvel
        right_qvel_raw = self.recorder_right.qvel
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    def get_effort(self):
        left_effort_raw = self.recorder_left.effort
        right_effort_raw = self.recorder_right.effort
        left_robot_effort = left_effort_raw[:7]
        right_robot_effort = right_effort_raw[:7]
        return np.concatenate([left_robot_effort, right_robot_effort])

    def get_images(self):
        return self.image_recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
        self.gripper_command.cmd = left_gripper_desired_joint
        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)

        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
            right_gripper_desired_pos_normalized
        )
        self.gripper_command.cmd = right_gripper_desired_joint
        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        robot_utils.move_arms(
            [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
        )

    def _reset_gripper(self):
        """Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
        )
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
        )

    def get_observation(self):
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def get_reward(self):
        return 0

    def reset(self, *, fake=False):
        if not fake:
            # Reboot puppet robot gripper motors
            self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
            self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
            self._reset_joints()
            self._reset_gripper()
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )

    def step(self, action):
        state_len = int(len(action) / 2)
        left_action = action[:state_len]
        right_action = action[state_len:]
        self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
        self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
        self.set_gripper_pose(left_action[-1], right_action[-1])
        time.sleep(constants.DT)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )


def get_action(master_bot_left, master_bot_right):
    action = np.zeros(14)  # 6 joint + 1 gripper, for two arms
    # Arm actions
    action[:6] = master_bot_left.dxl.joint_states.position[:6]
    action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
    # Gripper actions
    action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
    action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])

    return action


def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
    return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
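For orientation, a minimal usage sketch of the factory above (`master_bot_left` / `master_bot_right` are hypothetical `InterbotixManipulatorXS` handles for the leader arms, constructed elsewhere; they are not part of this file):

```python
# Minimal teleoperation sketch, assuming two master arms already exist.
env = make_real_env(init_node=True)
ts = env.reset()  # dm_env.TimeStep; observation holds qpos/qvel/effort/images
for _ in range(1000):
    # 14-dim action: 6 joints + 1 normalized gripper position per arm
    action = get_action(master_bot_left, master_bot_right)
    ts = env.step(action)  # publishes commands, then sleeps constants.DT
```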
18
examples/aloha_real_lyt/requirements.in
Executable file
@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets
156
examples/aloha_real_lyt/requirements.txt
Executable file
@@ -0,0 +1,156 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
    # via
    #   dm-control
    #   dm-env
    #   labmaze
    #   mujoco
catkin-pkg==1.0.0
    # via rospkg
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
contourpy==1.1.1
    # via matplotlib
cycler==0.12.1
    # via matplotlib
distro==1.9.0
    # via rospkg
dm-control==1.0.23
    # via -r examples/aloha_real/requirements.in
dm-env==1.6
    # via dm-control
dm-tree==0.1.8
    # via
    #   dm-control
    #   dm-env
docstring-parser==0.16
    # via tyro
docutils==0.20.1
    # via catkin-pkg
einops==0.8.0
    # via -r examples/aloha_real/requirements.in
etils==1.3.0
    # via mujoco
fonttools==4.55.2
    # via matplotlib
glfw==2.8.0
    # via
    #   dm-control
    #   mujoco
h5py==3.11.0
    # via -r examples/aloha_real/requirements.in
idna==3.10
    # via requests
importlib-resources==6.4.5
    # via etils
kiwisolver==1.4.7
    # via matplotlib
labmaze==1.0.6
    # via dm-control
lxml==5.3.0
    # via dm-control
markdown-it-py==3.0.0
    # via rich
matplotlib==3.7.5
    # via -r examples/aloha_real/requirements.in
mdurl==0.1.2
    # via markdown-it-py
modern-robotics==1.1.1
    # via -r examples/aloha_real/requirements.in
msgpack==1.1.0
    # via -r examples/aloha_real/requirements.in
mujoco==3.2.3
    # via dm-control
numpy==1.24.4
    # via
    #   -r examples/aloha_real/requirements.in
    #   contourpy
    #   dm-control
    #   dm-env
    #   h5py
    #   labmaze
    #   matplotlib
    #   modern-robotics
    #   mujoco
    #   opencv-python
    #   pyquaternion
    #   scipy
opencv-python==4.10.0.84
    # via -r examples/aloha_real/requirements.in
packaging==24.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
pexpect==4.9.0
    # via -r examples/aloha_real/requirements.in
pillow==10.4.0
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
protobuf==5.29.1
    # via dm-control
ptyprocess==0.7.0
    # via pexpect
pygments==2.18.0
    # via rich
pyopengl==3.1.7
    # via
    #   dm-control
    #   mujoco
pyparsing==3.1.4
    # via
    #   catkin-pkg
    #   dm-control
    #   matplotlib
pyquaternion==0.9.9
    # via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
    # via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
    # via
    #   catkin-pkg
    #   matplotlib
pyyaml==6.0.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   rospkg
requests==2.32.3
    # via
    #   -r examples/aloha_real/requirements.in
    #   dm-control
rich==13.9.4
    # via tyro
rospkg==1.5.1
    # via -r examples/aloha_real/requirements.in
scipy==1.10.1
    # via dm-control
setuptools==75.3.0
    # via
    #   catkin-pkg
    #   dm-control
    #   labmaze
shtab==1.7.1
    # via tyro
six==1.17.0
    # via python-dateutil
tqdm==4.67.1
    # via dm-control
typeguard==4.4.0
    # via tyro
typing-extensions==4.12.2
    # via
    #   etils
    #   rich
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/aloha_real/requirements.in
urllib3==2.2.3
    # via requests
websockets==14.1
    # via -r examples/aloha_real/requirements.in
zipp==3.20.2
    # via etils
275
examples/aloha_real_lyt/robot_utils.py
Executable file
@@ -0,0 +1,275 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time

from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState

from examples.aloha_real import constants


class ImageRecorder:
    def __init__(self, init_node=True, is_debug=False):
        self.is_debug = is_debug
        self.bridge = CvBridge()
        self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]

        if init_node:
            rospy.init_node("image_recorder", anonymous=True)
        for cam_name in self.camera_names:
            setattr(self, f"{cam_name}_rgb_image", None)
            setattr(self, f"{cam_name}_depth_image", None)
            setattr(self, f"{cam_name}_timestamp", 0.0)
            if cam_name == "cam_high":
                callback_func = self.image_cb_cam_high
            elif cam_name == "cam_low":
                callback_func = self.image_cb_cam_low
            elif cam_name == "cam_left_wrist":
                callback_func = self.image_cb_cam_left_wrist
            elif cam_name == "cam_right_wrist":
                callback_func = self.image_cb_cam_right_wrist
            else:
                raise NotImplementedError
            rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
            if self.is_debug:
                setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))

        self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
        time.sleep(0.5)

    def image_cb(self, cam_name, data):
        setattr(
            self,
            f"{cam_name}_rgb_image",
            self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
        )
        # setattr(
        #     self,
        #     f"{cam_name}_depth_image",
        #     self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
        # )
        setattr(
            self,
            f"{cam_name}_timestamp",
            data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
        )
        # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
        # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
        # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
        if self.is_debug:
            getattr(self, f"{cam_name}_timestamps").append(
                data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
            )

    def image_cb_cam_high(self, data):
        cam_name = "cam_high"
        return self.image_cb(cam_name, data)

    def image_cb_cam_low(self, data):
        cam_name = "cam_low"
        return self.image_cb(cam_name, data)

    def image_cb_cam_left_wrist(self, data):
        cam_name = "cam_left_wrist"
        return self.image_cb(cam_name, data)

    def image_cb_cam_right_wrist(self, data):
        cam_name = "cam_right_wrist"
        return self.image_cb(cam_name, data)

    def get_images(self):
        image_dict = {}
        for cam_name in self.camera_names:
            while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
                time.sleep(0.00001)
            rgb_image = getattr(self, f"{cam_name}_rgb_image")
            depth_image = getattr(self, f"{cam_name}_depth_image")
            self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
            image_dict[cam_name] = rgb_image
            image_dict[f"{cam_name}_depth"] = depth_image
        return image_dict

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        for cam_name in self.camera_names:
            image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
            print(f"{cam_name} {image_freq=:.2f}")
        print()


class Recorder:
    def __init__(self, side, init_node=True, is_debug=False):
        self.secs = None
        self.nsecs = None
        self.qpos = None
        self.effort = None
        self.arm_command = None
        self.gripper_command = None
        self.is_debug = is_debug

        if init_node:
            rospy.init_node("recorder", anonymous=True)
        rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_group",
            JointGroupCommand,
            self.puppet_arm_commands_cb,
        )
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_single",
            JointSingleCommand,
            self.puppet_gripper_commands_cb,
        )
        if self.is_debug:
            self.joint_timestamps = deque(maxlen=50)
            self.arm_command_timestamps = deque(maxlen=50)
            self.gripper_command_timestamps = deque(maxlen=50)
        time.sleep(0.1)

    def puppet_state_cb(self, data):
        self.qpos = data.position
        self.qvel = data.velocity
        self.effort = data.effort
        self.data = data
        if self.is_debug:
            self.joint_timestamps.append(time.time())

    def puppet_arm_commands_cb(self, data):
        self.arm_command = data.cmd
        if self.is_debug:
            self.arm_command_timestamps.append(time.time())

    def puppet_gripper_commands_cb(self, data):
        self.gripper_command = data.cmd
        if self.is_debug:
            self.gripper_command_timestamps.append(time.time())

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        joint_freq = 1 / dt_helper(self.joint_timestamps)
        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)

        print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")


def get_arm_joint_positions(bot):
    return bot.arm.core.joint_states.position[:6]


def get_arm_gripper_positions(bot):
    return bot.gripper.core.joint_states.position[6]


def move_arms(bot_list, target_pose_list, move_time=1):
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]
    for t in range(num_steps):
        for bot_id, bot in enumerate(bot_list):
            bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
        time.sleep(constants.DT)


def move_grippers(bot_list, target_pose_list, move_time):
    print(f"Moving grippers to {target_pose_list=}")
    gripper_command = JointSingleCommand(name="gripper")
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]

    with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
        for t in range(num_steps):
            d = {}
            for bot_id, bot in enumerate(bot_list):
                gripper_command.cmd = traj_list[bot_id][t]
                bot.gripper.core.pub_single.publish(gripper_command)
                d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
            f.write(json.dumps(d) + "\n")
            time.sleep(constants.DT)


def setup_puppet_bot(bot):
    bot.dxl.robot_reboot_motors("single", "gripper", True)
    bot.dxl.robot_set_operating_modes("group", "arm", "position")
    bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_on(bot)


def setup_master_bot(bot):
    bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
    bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_off(bot)


def set_standard_pid_gains(bot):
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)


def set_low_pid_gains(bot):
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)


def torque_off(bot):
    bot.dxl.robot_torque_enable("group", "arm", False)
    bot.dxl.robot_torque_enable("single", "gripper", False)


def torque_on(bot):
    bot.dxl.robot_torque_enable("group", "arm", True)
    bot.dxl.robot_torque_enable("single", "gripper", True)


# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
    print("\nSyncing!")

    # activate master arms
    torque_on(master_bot_left)
    torque_on(master_bot_right)

    # get puppet arm positions
    puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
    puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)

    # get puppet gripper positions
    puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
    puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)

    # move master arms to puppet positions
    move_arms(
        [master_bot_left, master_bot_right],
        [puppet_left_qpos, puppet_right_qpos],
        move_time=1,
    )

    # move master grippers to puppet positions
    move_grippers(
        [master_bot_left, master_bot_right],
        [puppet_left_gripper, puppet_right_gripper],
        move_time=1,
    )
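`move_arms` and `move_grippers` both drive a straight-line interpolation in joint space, one waypoint per control tick `constants.DT`. A self-contained sketch of the same scheme with plain numpy (the 0.02 s period is an assumption standing in for `constants.DT`):

```python
import numpy as np

DT = 0.02  # assumed control period, standing in for constants.DT


def joint_space_trajectory(curr_pose, target_pose, move_time=1.0):
    """Linear joint-space interpolation, as in move_arms/move_grippers."""
    num_steps = int(move_time / DT)
    return np.linspace(curr_pose, target_pose, num_steps)


traj = joint_space_trajectory(np.zeros(6), np.array([0.0, -0.96, 1.16, 0.0, -0.3, 0.0]))
print(traj.shape)  # (50, 6): one waypoint per DT, each sent with blocking=False
```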
36
examples/aloha_real_lyt/video_display.py
Executable file
@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoDisplay(_subscriber.Subscriber):
    """Displays video frames."""

    def __init__(self) -> None:
        self._ax: plt.Axes | None = None
        self._plt_img: plt.Image | None = None

    @override
    def on_episode_start(self) -> None:
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        assert self._ax is not None

        im = observation["image"][0]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]

        if self._plt_img is None:
            self._plt_img = self._ax.imshow(im)
        else:
            self._plt_img.set_data(im)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        plt.ioff()
        plt.close()
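The subscriber can be exercised standalone with synthetic frames; the `"image"` key and the `[num_cams, C, H, W]` layout follow `on_step` above, while the runtime wiring is omitted (a sketch, not the repo's entry point):

```python
import numpy as np

display = VideoDisplay()
display.on_episode_start()
for _ in range(10):
    # One camera, channel-first, as expected by on_step.
    frame = np.random.randint(256, size=(1, 3, 224, 224), dtype=np.uint8)
    display.on_step({"image": frame}, {})
display.on_episode_end()
```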
0
examples/aloha_sim/Dockerfile
Normal file → Executable file
0
examples/aloha_sim/README.md
Normal file → Executable file
0
examples/aloha_sim/compose.yml
Normal file → Executable file
0
examples/aloha_sim/env.py
Normal file → Executable file
0
examples/aloha_sim/main.py
Normal file → Executable file
0
examples/aloha_sim/requirements.in
Normal file → Executable file
0
examples/aloha_sim/requirements.txt
Normal file → Executable file
0
examples/aloha_sim/saver.py
Normal file → Executable file
2
examples/droid/README.md
Normal file → Executable file
@@ -27,7 +27,7 @@ uv run scripts/serve_policy.py --env=DROID
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera); choose from ["left", "right"].

```bash
11
examples/droid/main.py
Normal file → Executable file
@@ -6,7 +6,7 @@ import datetime
import faulthandler
import os
import signal

import time
from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
@@ -19,6 +19,9 @@ import tyro

faulthandler.enable()

# DROID data collection frequency -- we slow down execution to match this frequency
DROID_CONTROL_FREQUENCY = 15


@dataclasses.dataclass
class Args:
@@ -95,6 +98,7 @@ def main(args: Args):
    bar = tqdm.tqdm(range(args.max_timesteps))
    print("Running rollout... press Ctrl+C to stop early.")
    for t_step in bar:
        start_time = time.time()
        try:
            # Get the current observation
            curr_obs = _extract_observation(
@@ -145,6 +149,11 @@ def main(args: Args):
            action = np.clip(action, -1, 1)

            env.step(action)

            # Sleep to match DROID data collection frequency
            elapsed_time = time.time() - start_time
            if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
        except KeyboardInterrupt:
            break
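The added sleep is the usual fixed-rate control-loop pattern: time the iteration, then sleep for whatever remains of the period. The same pattern in isolation (15 Hz as above; `do_work` is a placeholder for observation, inference, and `env.step`):

```python
import time

CONTROL_FREQUENCY = 15  # Hz, as in DROID_CONTROL_FREQUENCY
PERIOD = 1 / CONTROL_FREQUENCY


def do_work():
    time.sleep(0.01)  # placeholder for observation + inference + env.step


for _ in range(30):
    start_time = time.time()
    do_work()
    elapsed = time.time() - start_time
    if elapsed < PERIOD:
        time.sleep(PERIOD - elapsed)  # pad the iteration out to the full period
```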
56
examples/inference.ipynb
Normal file → Executable file
@@ -6,6 +6,8 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['HF_ENDPOINT'] = \"https://hf-mirror.com\"\n",
"import dataclasses\n",
"\n",
"import jax\n",
@@ -18,6 +20,13 @@
"from openpi.training import data_loader as _data_loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
@@ -31,10 +40,53 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fa8d45bf6fe5420f8b152ff52794ee45",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%|          | 0.00/11.2G [00:00<?, ?iB/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 18\n",
"Some kwargs in processor config are unused and will not have any effect: action_dim, scale, time_horizon, vocab_size, min_token. \n",
"Some kwargs in processor config are unused and will not have any effect: action_dim, scale, time_horizon, vocab_size, min_token. \n"
]
},
{
"ename": "ValueError",
"evalue": "quantile stats must be provided if use_quantile_norm is True. Key actions is missing q01 or q99.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m checkpoint_dir \u001b[38;5;241m=\u001b[39m download\u001b[38;5;241m.\u001b[39mmaybe_download(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ms3://openpi-assets/checkpoints/pi0_base\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Create a trained policy.\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m policy \u001b[38;5;241m=\u001b[39m \u001b[43m_policy_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_trained_policy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m example \u001b[38;5;241m=\u001b[39m droid_policy\u001b[38;5;241m.\u001b[39mmake_droid_example()\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/policies/policy_config.py:72\u001b[0m, in \u001b[0;36mcreate_trained_policy\u001b[0;34m(train_config, checkpoint_dir, repack_transforms, sample_kwargs, default_prompt, norm_stats)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAsset id is required to load norm stats.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 64\u001b[0m norm_stats \u001b[38;5;241m=\u001b[39m _checkpoints\u001b[38;5;241m.\u001b[39mload_norm_stats(checkpoint_dir \u001b[38;5;241m/\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massets\u001b[39m\u001b[38;5;124m\"\u001b[39m, data_config\u001b[38;5;241m.\u001b[39masset_id)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _policy\u001b[38;5;241m.\u001b[39mPolicy(\n\u001b[1;32m 67\u001b[0m model,\n\u001b[1;32m 68\u001b[0m transforms\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 69\u001b[0m \u001b[38;5;241m*\u001b[39mrepack_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[1;32m 70\u001b[0m transforms\u001b[38;5;241m.\u001b[39mInjectDefaultPrompt(default_prompt),\n\u001b[1;32m 71\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mdata_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[0;32m---> 72\u001b[0m \u001b[43mtransforms\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mNormalize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnorm_stats\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_quantiles\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_quantile_norm\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 73\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mmodel_transforms\u001b[38;5;241m.\u001b[39minputs,\n\u001b[1;32m 74\u001b[0m ],\n\u001b[1;32m 75\u001b[0m output_transforms\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 76\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mmodel_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 77\u001b[0m transforms\u001b[38;5;241m.\u001b[39mUnnormalize(norm_stats, use_quantiles\u001b[38;5;241m=\u001b[39mdata_config\u001b[38;5;241m.\u001b[39muse_quantile_norm),\n\u001b[1;32m 78\u001b[0m \u001b[38;5;241m*\u001b[39mdata_config\u001b[38;5;241m.\u001b[39mdata_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 79\u001b[0m \u001b[38;5;241m*\u001b[39mrepack_transforms\u001b[38;5;241m.\u001b[39moutputs,\n\u001b[1;32m 80\u001b[0m ],\n\u001b[1;32m 81\u001b[0m sample_kwargs\u001b[38;5;241m=\u001b[39msample_kwargs,\n\u001b[1;32m 82\u001b[0m metadata\u001b[38;5;241m=\u001b[39mtrain_config\u001b[38;5;241m.\u001b[39mpolicy_metadata,\n\u001b[1;32m 83\u001b[0m )\n",
"File \u001b[0;32m<string>:6\u001b[0m, in \u001b[0;36m__init__\u001b[0;34m(self, norm_stats, use_quantiles, strict)\u001b[0m\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/transforms.py:124\u001b[0m, in \u001b[0;36mNormalize.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__post_init__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm_stats \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_quantiles:\n\u001b[0;32m--> 124\u001b[0m \u001b[43m_assert_quantile_stats\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_stats\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/LYT/lerobot_aloha/openpi/src/openpi/transforms.py:431\u001b[0m, in \u001b[0;36m_assert_quantile_stats\u001b[0;34m(norm_stats)\u001b[0m\n\u001b[1;32m 429\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m flatten_dict(norm_stats)\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m v\u001b[38;5;241m.\u001b[39mq01 \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m v\u001b[38;5;241m.\u001b[39mq99 \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 431\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 432\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantile stats must be provided if use_quantile_norm is True. Key \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is missing q01 or q99.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 433\u001b[0m )\n",
"\u001b[0;31mValueError\u001b[0m: quantile stats must be provided if use_quantile_norm is True. Key actions is missing q01 or q99."
]
}
],
"source": [
"\n",
"config = _config.get_config(\"pi0_fast_droid\")\n",
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
"# checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_base\")\n",
"\n",
"# Create a trained policy.\n",
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
@@ -129,7 +181,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.11.12"
}
},
"nbformat": 4,
0
examples/libero/Dockerfile
Normal file → Executable file
0
examples/libero/README.md
Normal file → Executable file
0
examples/libero/compose.yml
Normal file → Executable file
0
examples/libero/convert_libero_data_to_lerobot.py
Normal file → Executable file
0
examples/libero/main.py
Normal file → Executable file
0
examples/libero/requirements.in
Normal file → Executable file
0
examples/libero/requirements.txt
Normal file → Executable file
32
examples/lyt_simple_client/Dockerfile
Executable file
@@ -0,0 +1,32 @@
# Dockerfile for the simple client.

# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash

FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
30
examples/lyt_simple_client/README.md
Executable file
@@ -0,0 +1,30 @@
# Simple Client

A minimal client that sends observations to the server and prints the inference rate.

You can specify which runtime environment to use with the `--env` flag. You can see the available options by running:

```bash
uv run examples/simple_client/main.py --help
```

## With Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/simple_client/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
uv run examples/simple_client/main.py --env DROID
```

Terminal window 2:

```bash
uv run scripts/serve_policy.py --env DROID
```
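Conceptually the client is just a loop around `WebsocketClientPolicy.infer`; a minimal sketch follows (the host, port, and single-camera observation are placeholder assumptions; the full AGILEX observation format used here is in `examples/lyt_simple_client/main.py` below):

```python
import time

import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy

# Placeholder address; point this at your policy server.
policy = _websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

obs = {
    "state": np.ones((14,)),
    "images": {"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)},
    "prompt": "do something",
}
start = time.time()
for _ in range(10):
    policy.infer(obs)
print(f"Average inference time: {1000 * (time.time() - start) / 10:.2f} ms")
```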
353
examples/lyt_simple_client/agilex_utils.py
Normal file
@@ -0,0 +1,353 @@
import numpy as np
from einops import rearrange
from collections import deque

import rospy
from std_msgs.msg import Header
from geometry_msgs.msg import Twist
from sensor_msgs.msg import JointState, Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
import threading


class RosOperator:
    def __init__(self, args):
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.ctrl_state = False
        self.ctrl_state_lock = threading.Lock()
        self.init()
        self.init_ros()

    def init(self):
        self.bridge = CvBridge()
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()
        self.robot_base_deque = deque()
        self.puppet_arm_publish_lock = threading.Lock()
        self.puppet_arm_publish_lock.acquire()

    def puppet_arm_publish(self, left, right):
        # Default velocity and effort values
        last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
                         0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                         -0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                         -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                         -0.03296661376953125, -0.03296661376953125]

        last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
                       0.1527481079101562, -0.013187408447265625, -0.013187408447265625,
                       0.0, -0.03076934814453125, -0.3296699523925781, 0.43956756591797,
                       0.5208797454833984, -0.11868095397949219, 0.03956031799316406, 0.0]
        # Rescale the gripper position
        left[-1] *= 12
        right[-1] *= 12
        # Must always be non-negative; clip values below 0 to 0
        left[-1] = max(left[-1], 0)
        right[-1] = max(right[-1], 0)

        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # set the timestamp
        joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set the joint names
        joint_state_msg.position = left
        joint_state_msg.velocity = last_velocity[:7]
        joint_state_msg.effort = last_effort[:7]
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right
        joint_state_msg.velocity = last_velocity[7:]
        joint_state_msg.effort = last_effort[7:]
        self.puppet_arm_right_publisher.publish(joint_state_msg)

    def robot_base_publish(self, vel):
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)

    def puppet_arm_publish_continuous(self, left, right):
        rate = rospy.Rate(self.args.publish_rate)
        left_arm = None
        right_arm = None
        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break
        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
        flag = True
        step = 0
        while flag and not rospy.is_shutdown():
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
            right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
            flag = False
            for i in range(len(left)):
                if left_diff[i] < self.args.arm_steps_length[i]:
                    left_arm[i] = left[i]
                else:
                    left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            for i in range(len(right)):
                if right_diff[i] < self.args.arm_steps_length[i]:
                    right_arm[i] = right[i]
                else:
                    right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # set the timestamp
            joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set the joint names
            joint_state_msg.position = left_arm
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = right_arm
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()

    def puppet_arm_publish_linear(self, left, right):
        num_step = 100
        rate = rospy.Rate(200)

        left_arm = None
        right_arm = None

        while True and not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break

        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)

        for i in range(len(traj_left_list)):
            traj_left = traj_left_list[i]
            traj_right = traj_right_list[i]
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # set the timestamp
            joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set the joint names
            joint_state_msg.position = traj_left
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = traj_right
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            rate.sleep()

    def puppet_arm_publish_continuous_thread(self, left, right):
        if self.puppet_arm_publish_thread is not None:
            self.puppet_arm_publish_lock.release()
            self.puppet_arm_publish_thread.join()
            self.puppet_arm_publish_lock.acquire(False)
            self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
        self.puppet_arm_publish_thread.start()

    def get_frame(self):
        if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
                (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
            return False
        if self.args.use_depth_image:
            frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
                              self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
        else:
            frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])

        if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
            return False

        while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
            self.img_left_deque.popleft()
        img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')

        while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
            self.img_right_deque.popleft()
        img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')

        while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
            self.img_front_deque.popleft()
        img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')

        while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_left_deque.popleft()
        puppet_arm_left = self.puppet_arm_left_deque.popleft()

        while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_right_deque.popleft()
        puppet_arm_right = self.puppet_arm_right_deque.popleft()

        img_left_depth = None
        if self.args.use_depth_image:
            while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_left_depth_deque.popleft()
            img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')

        img_right_depth = None
        if self.args.use_depth_image:
            while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_right_depth_deque.popleft()
            img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')

        img_front_depth = None
        if self.args.use_depth_image:
            while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_front_depth_deque.popleft()
            img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')

        robot_base = None
        if self.args.use_robot_base:
            while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
                self.robot_base_deque.popleft()
            robot_base = self.robot_base_deque.popleft()

        return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
                puppet_arm_left, puppet_arm_right, robot_base)

    def img_left_callback(self, msg):
        if len(self.img_left_deque) >= 2000:
            self.img_left_deque.popleft()
        self.img_left_deque.append(msg)

    def img_right_callback(self, msg):
        if len(self.img_right_deque) >= 2000:
            self.img_right_deque.popleft()
        self.img_right_deque.append(msg)

    def img_front_callback(self, msg):
        if len(self.img_front_deque) >= 2000:
            self.img_front_deque.popleft()
        self.img_front_deque.append(msg)

    def img_left_depth_callback(self, msg):
        if len(self.img_left_depth_deque) >= 2000:
            self.img_left_depth_deque.popleft()
        self.img_left_depth_deque.append(msg)

    def img_right_depth_callback(self, msg):
        if len(self.img_right_depth_deque) >= 2000:
            self.img_right_depth_deque.popleft()
        self.img_right_depth_deque.append(msg)

    def img_front_depth_callback(self, msg):
        if len(self.img_front_depth_deque) >= 2000:
            self.img_front_depth_deque.popleft()
        self.img_front_depth_deque.append(msg)

    def puppet_arm_left_callback(self, msg):
        if len(self.puppet_arm_left_deque) >= 2000:
            self.puppet_arm_left_deque.popleft()
        self.puppet_arm_left_deque.append(msg)

    def puppet_arm_right_callback(self, msg):
        if len(self.puppet_arm_right_deque) >= 2000:
            self.puppet_arm_right_deque.popleft()
        self.puppet_arm_right_deque.append(msg)

    def robot_base_callback(self, msg):
        if len(self.robot_base_deque) >= 2000:
            self.robot_base_deque.popleft()
        self.robot_base_deque.append(msg)

    def ctrl_callback(self, msg):
        self.ctrl_state_lock.acquire()
        self.ctrl_state = msg.data
        self.ctrl_state_lock.release()

    def get_ctrl_state(self):
        self.ctrl_state_lock.acquire()
        state = self.ctrl_state
        self.ctrl_state_lock.release()
        return state

    def init_ros(self):
        rospy.init_node('joint_state_publisher', anonymous=True)
        rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
        if self.args.use_depth_image:
            rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
        # rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
        # self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
        # self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
        # self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.img_left_topic = '/camera_l/color/image_raw'
    args.img_right_topic = '/camera_r/color/image_raw'
    args.img_front_topic = '/camera_f/color/image_raw'

    args.puppet_arm_left_cmd_topic = '/master/joint_left'
    args.puppet_arm_right_cmd_topic = '/master/joint_right'
    args.puppet_arm_left_topic = '/puppet/joint_left'
    args.puppet_arm_right_topic = '/puppet/joint_right'

    args.publish_rate = 30
    args.use_robot_base = False
    args.use_actions_interpolation = False
    args.use_depth_image = False
    a = RosOperator(args)
    print(a)
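`get_frame` synchronizes the camera and joint-state streams by taking the minimum of the newest timestamps and discarding everything older than that instant. The same idea on plain `(timestamp, value)` deques, as a minimal sketch:

```python
from collections import deque


def align(streams):
    """Align multiple (timestamp, value) deques to a common frame time.

    Mirrors RosOperator.get_frame: frame_time is the min over the newest
    timestamps; each deque is popped until its head reaches frame_time.
    """
    if any(len(d) == 0 for d in streams.values()):
        return None  # not all streams have data yet
    frame_time = min(d[-1][0] for d in streams.values())
    synced = {}
    for name, d in streams.items():
        while d[0][0] < frame_time:
            d.popleft()
        synced[name] = d.popleft()[1]
    return synced


streams = {
    "cam_left": deque([(0.01, "L0"), (0.05, "L1")]),
    "cam_right": deque([(0.02, "R0"), (0.04, "R1")]),
}
print(align(streams))  # {'cam_left': 'L1', 'cam_right': 'R1'}
```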
42
examples/lyt_simple_client/compose.yml
Executable file
@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
  runtime:
    image: simple_client
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/simple_client/Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      - SERVER_ARGS

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
206
examples/lyt_simple_client/main.py
Executable file
@@ -0,0 +1,206 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tyro
|
||||
import rospy
|
||||
from std_msgs.msg import Header
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from agilex_utils import RosOperator
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
LIBERO = "libero"
|
||||
AGILEX_ALOHA = "agilex_arx_3camera_aloha"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "172.20.103.171"
|
||||
port: int = 8090
|
||||
|
||||
env: EnvMode = EnvMode.AGILEX_ALOHA
|
||||
num_steps: int = 10
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
obs_fn = {
|
||||
EnvMode.ALOHA: _random_observation_aloha,
|
||||
EnvMode.ALOHA_SIM: _random_observation_aloha,
|
||||
EnvMode.DROID: _random_observation_droid,
|
||||
EnvMode.LIBERO: _random_observation_libero,
|
||||
EnvMode.AGILEX_ALOHA: observation_agilex_3camera_aloha,
|
||||
}[args.env]
|
||||
|
||||
policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
logging.info(f"Server metadata: {policy.get_server_metadata()}")
|
||||
|
||||
args_ros, ros_operator = init_agilex_3camera_aloha()
|
||||
|
||||
# Send 1 observation to make sure the model is loaded.
|
||||
policy.infer(obs_fn(args_ros, ros_operator))
|
||||
|
||||
# test inference
|
||||
start = time.time()
|
||||
for _ in range(10):
|
||||
policy.infer(obs_fn(args_ros, ros_operator))
|
||||
end = time.time()
|
||||
|
||||
print(f"Total time taken: {end - start:.2f} s")
|
||||
print(f"Average inference time: {1000 * (end - start) / args.num_steps:.2f} ms")
|
||||
if 1000 * (end - start) / args.num_steps < 500:
|
||||
logging.info("Inference time is less than 0.5 second! Its good!")
|
||||
else:
|
||||
logging.warning("Inference time is more than 0.5 second! Its bad!")
|
||||
|
||||
|
||||
    # Publishers for the two master arms.
    master_arm_left_publisher = rospy.Publisher(args_ros.master_arm_left_topic, JointState, queue_size=10)
    master_arm_right_publisher = rospy.Publisher(args_ros.master_arm_right_topic, JointState, queue_size=10)
    joint_state_msg = JointState()
    joint_state_msg.header = Header()
    joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # Set the joint names.
    rate = rospy.Rate(30)
    # Default velocity and effort values.
    last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
                     0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                     -0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                     -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                     -0.03296661376953125, -0.03296661376953125]

    last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
                   0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
                   0.0, -0.010990142822265625, -0.010990142822265625,
                   -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                   -0.03296661376953125, -0.03296661376953125]
    while True:
        actions = policy.infer(obs_fn(args_ros, ros_operator))['actions']
        for idx, action in enumerate(actions):
            if rospy.is_shutdown():
                break
            print(idx, np.round(action[:7], 4))
            joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestamp.

            # The first 7 dims drive the left arm, the remaining 7 the right arm.
            joint_state_msg.position = action[:7]
            joint_state_msg.velocity = last_velocity[:7]
            joint_state_msg.effort = last_effort[:7]
            master_arm_left_publisher.publish(joint_state_msg)

            joint_state_msg.position = action[7:]
            joint_state_msg.velocity = last_velocity[7:]
            joint_state_msg.effort = last_effort[7:]
            master_arm_right_publisher.publish(joint_state_msg)
            if rospy.is_shutdown():
                break
            rate.sleep()
def init_agilex_3camera_aloha():
    import argparse

    # Build a plain namespace rather than calling parser.parse_args(), which
    # would clash with the tyro flags already present on the command line.
    args = argparse.Namespace()
    args.img_left_topic = '/camera_l/color/image_raw'
    args.img_right_topic = '/camera_r/color/image_raw'
    args.img_front_topic = '/camera_f/color/image_raw'

    args.master_arm_left_topic = '/master/joint_left'
    args.master_arm_right_topic = '/master/joint_right'
    args.puppet_arm_left_topic = '/puppet/joint_left'
    args.puppet_arm_right_topic = '/puppet/joint_right'

    args.publish_rate = 30
    args.use_robot_base = False
    args.use_actions_interpolation = False
    args.use_depth_image = False

    ros_operator = RosOperator(args)
    return args, ros_operator
def observation_agilex_3camera_aloha(args, ros_operator: RosOperator):
    print_flag = True
    rate = rospy.Rate(args.publish_rate)
    while not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if not result:
            if print_flag:
                print("Frame sync failed, waiting for synchronized topics...")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
         puppet_arm_left, puppet_arm_right, robot_base) = result
        break

    state = np.concatenate([
        puppet_arm_left.position, puppet_arm_right.position
    ])
    # Convert images from (H, W, C) to the (C, H, W) layout expected by the server.
    img_front = np.transpose(img_front, (2, 0, 1))
    img_left = np.transpose(img_left, (2, 0, 1))
    img_right = np.transpose(img_right, (2, 0, 1))
    return {
        "state": state,
        "images": {
            "cam_high": img_front,
            "cam_left_wrist": img_left,
            "cam_right_wrist": img_right,
        },
        "prompt": "weigh a reagent by a balance",
    }
def _random_observation_aloha() -> dict:
    return {
        "state": np.ones((14,)),
        "images": {
            "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
        },
        "prompt": "do something",
    }


def _random_observation_droid() -> dict:
    return {
        "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/joint_position": np.random.rand(7),
        "observation/gripper_position": np.random.rand(1),
        "prompt": "do something",
    }


def _random_observation_libero() -> dict:
    return {
        "observation/state": np.random.rand(8),
        "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "prompt": "do something",
    }


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main(tyro.cli(Args))
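For context, a typical way to exercise a client like this is to start a policy server in one terminal and run the client in another. The commands below are a sketch, not taken from this diff: the served environment, host, and port are illustrative and must match your setup (this client defaults to 172.20.103.171:8090).

```bash
# Terminal 1: start a policy server (environment choice illustrative).
uv run scripts/serve_policy.py --env ALOHA

# Terminal 2: point the client at the server; host/port must match.
uv run examples/lyt_simple_client/main.py --host 172.20.103.171 --port 8090
```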
2 examples/lyt_simple_client/requirements.in Executable file
@@ -0,0 +1,2 @@
numpy
tyro
27 examples/lyt_simple_client/requirements.txt Executable file
@@ -0,0 +1,27 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
backports-cached-property==1.0.2
    # via tyro
docstring-parser==0.16
    # via tyro
eval-type-backport==0.1.3
    # via tyro
markdown-it-py==2.2.0
    # via rich
mdurl==0.1.2
    # via markdown-it-py
numpy==1.21.6
    # via -r examples/simple_client/requirements.in
pygments==2.17.2
    # via rich
rich==13.8.1
    # via tyro
shtab==1.7.1
    # via tyro
typing-extensions==4.7.1
    # via
    #   markdown-it-py
    #   rich
    #   tyro
tyro==0.9.1
    # via -r examples/simple_client/requirements.in
0 examples/policy_records.ipynb Normal file → Executable file
0 examples/simple_client/Dockerfile Normal file → Executable file
4 examples/simple_client/README.md Normal file → Executable file
@@ -2,7 +2,7 @@

A minimal client that sends observations to the server and prints the inference rate.

You can specifiy which runtime environment to use using the `--env` flag. You can see the available options by running:
You can specify which runtime environment to use using the `--env` flag. You can see the available options by running:

```bash
uv run examples/simple_client/main.py --help
@@ -27,4 +27,4 @@ Terminal window 2:

```bash
uv run scripts/serve_policy.py --env DROID
```
```
0 examples/simple_client/compose.yml Normal file → Executable file
0 examples/simple_client/main.py Normal file → Executable file
0 examples/simple_client/requirements.in Normal file → Executable file
0 examples/simple_client/requirements.txt Normal file → Executable file
151 examples/ur5/README.md Executable file
@@ -0,0 +1,151 @@
# UR5 Example

Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for fine-tuning on UR5 datasets.

First, we define the `UR5Inputs` and `UR5Outputs` classes, which map UR5 observations into the model's input format and model outputs back into robot actions. Check the corresponding file, `src/openpi/policies/libero_policy.py`, for comments explaining each line; a sketch of the `_parse_image` helper used below follows the code block.

```python
@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    action_dim: int
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        mask_padding = self.model_type == _model.ModelType.PI0

        # First, concatenate the joints and gripper into the state vector.
        # Pad to the expected input dimensionality of the model (same as action_dim).
        state = np.concatenate([data["joints"], data["gripper"]])
        state = transforms.pad_to_dim(state, self.action_dim)

        # Parse images to uint8 (H, W, C) if needed: LeRobot automatically stores
        # images as float32 (C, H, W). This step is skipped during policy inference.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create the inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist camera, replace it with zeros.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # Since the "slot" for the right wrist is not used, this mask is set to False.
                "right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
            },
        }

        # Pad actions to the model action dimension.
        if "actions" in data:
            # The robot produces 7D actions (6 DoF + 1 gripper), and we pad these.
            actions = transforms.pad_to_dim(data["actions"], self.action_dim)
            inputs["actions"] = actions

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```
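The transforms above call a `_parse_image` helper that is not shown in this README. A minimal sketch of such a helper, assuming the same conventions as the LIBERO example (LeRobot stores float32 images in (C, H, W); the model side expects uint8 (H, W, C)):

```python
import numpy as np


def _parse_image(image) -> np.ndarray:
    """Convert an image to uint8 (H, W, C), the layout the model transforms expect."""
    image = np.asarray(image)
    # LeRobot stores images as float32 in [0, 1]; rescale to uint8.
    if np.issubdtype(image.dtype, np.floating):
        image = (255 * image).astype(np.uint8)
    # Channel-first (C, H, W) -> channel-last (H, W, C).
    if image.shape[0] == 3:
        image = np.transpose(image, (1, 2, 0))
    return image
```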
Next, we define the `LeRobotUR5DataConfig` class, which defines how to process raw UR5 data from a LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py). A sketch of the `make_bool_mask` helper used below follows the code block.

```python
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming is needed here.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
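To make the delta-action masking concrete: our reading of `make_bool_mask` (a sketch, not the exact library source) is that positive arguments emit that many `True` entries and negative arguments that many `False` entries, so `make_bool_mask(6, -1)` marks the six joint dimensions for delta conversion and leaves the gripper absolute:

```python
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Positive entries yield that many True values, negative entries that many False values."""
    mask: list[bool] = []
    for d in dims:
        mask.extend([d > 0] * abs(d))
    return tuple(mask)


# The six joint dims become deltas; the 7th (gripper) dim stays absolute.
assert make_bool_mask(6, -1) == (True, True, True, True, True, True, False)
```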
Finally, we define a `TrainConfig` for fine-tuning pi0 on the UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning. The commands for launching training follow the code block.

```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            local_files_only=True,  # True if the dataset is saved locally.
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # ``task`` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```
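Assuming the config above is registered in the training config registry under the name `pi0_ur5`, fine-tuning then follows the usual two-step recipe (experiment name illustrative):

```bash
# Compute normalization statistics for the new dataset.
uv run scripts/compute_norm_stats.py --config-name pi0_ur5

# Launch fine-tuning from the pi0 base checkpoint.
uv run scripts/train.py pi0_ur5 --exp-name=ur5_finetune --overwrite
```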
0 packages/openpi-client/pyproject.toml Normal file → Executable file
0 packages/openpi-client/src/openpi_client/__init__.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/action_chunk_broker.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/base_policy.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/image_tools.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/image_tools_test.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/msgpack_numpy.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/msgpack_numpy_test.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/runtime/agent.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/runtime/environment.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/runtime/runtime.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/runtime/subscriber.py Normal file → Executable file
0 packages/openpi-client/src/openpi_client/websocket_client_policy.py Normal file → Executable file
5 pyproject.toml Normal file → Executable file
@@ -3,7 +3,7 @@ name = "openpi"
version = "0.1.0"
description = "Physical Intelligence open source repo"
readme = "README.md"
requires-python = ">=3.11"
requires-python = ">=3.10"
license = { file = "LICENSE" }
dependencies = [
    "augmax>=0.3.4",
@@ -21,7 +21,6 @@ dependencies = [
    "ml_collections==1.0.0",
    "numpy>=1.26.4",
    "numpydantic>=1.6.6",
    "opencv-python>=4.10.0.84",
    "openpi-client",
    "orbax-checkpoint==0.11.1",
    "pillow>=11.0.0",
@@ -65,7 +64,7 @@ members = ["packages/*"]

[tool.ruff]
line-length = 120
target-version = "py311"
target-version = "py310"
extend-exclude = ["docker", "third_party"]

[tool.ruff.lint]
0 scripts/__init__.py Normal file → Executable file
0 scripts/compute_norm_stats.py Normal file → Executable file
4 scripts/docker/compose.yml Normal file → Executable file
@@ -1,10 +1,10 @@
# Run with:
# docker compose -f scripts/compose.yml up --build
# docker compose -f scripts/docker/compose.yml up --build
services:
  openpi_server:
    image: openpi_server
    build:
      context: ..
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
0 scripts/docker/serve_policy.Dockerfile Normal file → Executable file
0 scripts/serve_policy.py Normal file → Executable file
1 scripts/train.py Normal file → Executable file
@@ -199,7 +199,6 @@ def main(config: _config.TrainConfig):
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_threefry_partitionable", True)  # noqa: FBT003
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
0 scripts/train_test.py Normal file → Executable file
0 src/openpi/__init__.py Normal file → Executable file
0 src/openpi/conftest.py Normal file → Executable file
0 src/openpi/models/__init__.py Normal file → Executable file
0 src/openpi/models/gemma.py Normal file → Executable file
86 src/openpi/models/gemma_fast.py Normal file → Executable file
@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
Used for FAST autoregressive policies.
"""

import dataclasses
from typing import Literal, TypeAlias

import einops
@@ -25,9 +26,10 @@ import jax
import jax.numpy as jnp
import ml_collections

import openpi.models.lora as lora
import openpi.shared.array_typing as at

Variant = Literal["gemma_2b"]
Variant = Literal["gemma_2b", "gemma_2b_lora"]


def get_config(variant):
@@ -48,6 +50,26 @@ def get_config(variant):
            "remat_policy": "nothing_saveable",
        }
    )
    if variant == "gemma_2b_lora":
        return ml_collections.ConfigDict(
            {
                "variant": variant,
                "width": 2048,
                "depth": 18,
                "mlp_dim": 16_384,
                "num_heads": 8,
                "num_kv_heads": 1,
                "head_dim": 256,
                "norm_eps": 1e-6,
                "vocab_size": 257_152,
                "scan": True,
                "remat_policy": "nothing_saveable",
                "lora_configs": {
                    "attn": lora.LoRAConfig(rank=16, alpha=16.0),
                    "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
                },
            }
        )
    raise ValueError(f"Unknown variant: {variant}")
@@ -110,21 +132,34 @@ class Attention(nn.Module):

    cache_dtype: str | None = None

    lora_config: lora.LoRAConfig | None = None

    def setup(self):
        if self.num_kv_heads == self.num_heads:
            self.qkv_einsum = Einsum(
            self.qkv_einsum = lora.Einsum(
                shape=(3, self.num_heads, self.features, self.head_dim),
                name="qkv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        else:
            # MQA
            self.q_einsum = Einsum(
            self.q_einsum = lora.Einsum(
                shape=(self.num_heads, self.features, self.head_dim),
                name="q_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                lora_config=self.lora_config,
            )
            self.kv_einsum = Einsum(
            self.kv_einsum = lora.Einsum(
                shape=(2, self.num_kv_heads, self.features, self.head_dim),
                name="kv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                lora_config=self.lora_config,
            )
        self.attn_vec_einsum = Einsum(
        self.attn_vec_einsum = lora.Einsum(
            shape=(self.num_heads, self.head_dim, self.features),
            name="attn_vec_einsum",
            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            lora_config=self.lora_config,
        )

    def _init_cache(self, k, v, cache_size):
@@ -189,37 +224,6 @@ class Attention(nn.Module):
        return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache


@at.typecheck
class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # original dtype, could be half-precision
        w_gating = self.param(
            "gating_einsum",
            nn.initializers.zeros_init(),
            ((2, self.features, self.hidden_dim)),
        ).astype(dtype)
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            "linear",
            nn.initializers.zeros_init(),
            (self.hidden_dim, self.features),
        ).astype(dtype)
        outputs = jnp.dot(activations, w_linear)
        assert outputs.dtype == dtype
        return outputs


@at.typecheck
class Block(nn.Module):
    """Transformer block."""
@@ -233,6 +237,7 @@ class Block(nn.Module):
    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()
    cache_dtype: str | None = None
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    def setup(self):
        self.pre_attention_norm = RMSNorm()
@@ -242,9 +247,12 @@ class Block(nn.Module):
            features=self.embed_dim,
            head_dim=self.head_dim,
            cache_dtype=self.cache_dtype,
            lora_config=self.lora_configs.get("attn"),
        )
        self.pre_ffw_norm = RMSNorm()
        self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
        self.mlp = lora.FeedForward(
            features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
        )
        if self.dropout:
            self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
        else:
@@ -289,6 +297,7 @@ class Module(nn.Module):

    scan: bool = False
    remat_policy: str = "none"
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    @nn.compact
    def __call__(
@@ -380,6 +389,7 @@ class Module(nn.Module):
            "dropout": self.dropout,
            "dropout_bdims": self.dropout_bdims,
            "cache_dtype": self.cache_dtype,
            "lora_configs": self.lora_configs,
        }
        layers = self.scope.push("layers")
        blocks = [
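For intuition about what `lora.Einsum` and `lora.FeedForward` add on top of the frozen base weights: LoRA learns a low-rank update, y = x(W + (alpha/rank) * A B), with A and B much smaller than W. A self-contained JAX sketch of that idea (shapes, initialization, and scaling are illustrative, not the exact `openpi.models.lora` implementation):

```python
import jax
import jax.numpy as jnp


def lora_apply(x, w, lora_a, lora_b, alpha: float, rank: int):
    """y = x @ (W + (alpha / rank) * A @ B), without materializing the merged weight."""
    return x @ w + (x @ lora_a) @ lora_b * (alpha / rank)


k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(k1, (4, 2048))                 # batch of activations
w = jax.random.normal(k2, (2048, 256)) * 0.02        # frozen base projection
rank, alpha = 16, 16.0                               # matches LoRAConfig(rank=16, alpha=16.0)
lora_a = jax.random.normal(k3, (2048, rank)) * 0.02  # trainable down-projection
lora_b = jnp.zeros((rank, 256))                      # zero-init: the update starts as a no-op
print(lora_apply(x, w, lora_a, lora_b, alpha, rank).shape)  # (4, 256)
```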
0 src/openpi/models/lora.py Normal file → Executable file
Some files were not shown because too many files have changed in this diff.