Compare commits
21 Commits
user/miche
...
user/fraca
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b5fce823f | ||
|
|
2cce85b5dd | ||
|
|
b2d003e6eb | ||
|
|
200ba1feb5 | ||
|
|
0fc9a4341f | ||
|
|
d40e74f371 | ||
|
|
40237f5ea3 | ||
|
|
2bcdb57854 | ||
|
|
e9ca1b612d | ||
|
|
169babd621 | ||
|
|
a9031ee1be | ||
|
|
fc107a2c6e | ||
|
|
84fabbf4af | ||
|
|
5322417c03 | ||
|
|
4041f57943 | ||
|
|
2c86fea78a | ||
|
|
437fc29e12 | ||
|
|
aee86b4b18 | ||
|
|
1c873df5c0 | ||
|
|
145fe4cd17 | ||
|
|
e004247ed4 |
@@ -36,8 +36,8 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.30.2
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.31.1
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
@@ -48,7 +48,7 @@ repos:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.10
|
||||
rev: v0.11.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -57,12 +57,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.0
|
||||
rev: v8.24.2
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.4.1
|
||||
rev: v1.5.2
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
|
||||
@@ -98,14 +98,14 @@ conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, if you don't have `ffmpeg` in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
conda install ffmpeg
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install --no-binary=av -e .
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
||||
@@ -118,7 +118,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[aloha, pusht]"
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
|
||||
@@ -17,12 +17,21 @@
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import rerun as rr
|
||||
|
||||
# see https://rerun.io/docs/howto/visualization/limit-ram
|
||||
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
|
||||
|
||||
|
||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
|
||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
|
||||
rr.init("lerobot_capture_camera_feed")
|
||||
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
|
||||
|
||||
now = dt.datetime.now()
|
||||
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
||||
if not capture_dir.exists():
|
||||
@@ -39,24 +48,21 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
||||
|
||||
frame_index = 0
|
||||
while True:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < duration:
|
||||
ret, frame = cap.read()
|
||||
|
||||
if not ret:
|
||||
print("Error: Could not read frame.")
|
||||
break
|
||||
|
||||
cv2.imshow("Video Stream", frame)
|
||||
rr.log("video/stream", rr.Image(frame.numpy()), static=True)
|
||||
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
||||
frame_index += 1
|
||||
|
||||
# Break the loop on 'q' key press
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
# Release the capture and destroy all windows
|
||||
# Release the capture
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -86,5 +92,11 @@ if __name__ == "__main__":
|
||||
default=720,
|
||||
help="Height of the captured images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--duration",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Duration in seconds for which the video stream should be captured.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
display_and_save_video_stream(**vars(args))
|
||||
|
||||
@@ -57,9 +57,15 @@ conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
#### 5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
|
||||
@@ -491,6 +497,9 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
#### a. Teleop with displaying cameras
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
|
||||
@@ -67,9 +67,15 @@ conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
#### 5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## C. Install LeRobot on laptop
|
||||
@@ -108,9 +114,15 @@ conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
#### 5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
||||
@@ -393,6 +405,10 @@ python lerobot/scripts/control_robot.py \
|
||||
```
|
||||
|
||||
# F. Teleoperate
|
||||
|
||||
> [!TIP]
|
||||
> If you're using a Mac, you might need to give Terminal permission to access your keyboard. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
|
||||
|
||||
To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
@@ -408,6 +424,8 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.fps=30
|
||||
```
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port`
|
||||
|
||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||
| ---------- | ------------------ | ---------------------- |
|
||||
|
||||
@@ -31,9 +31,15 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the feetech motors:
|
||||
5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## Configure the motors
|
||||
@@ -212,6 +218,9 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
|
||||
@@ -119,7 +119,7 @@ print(dataset.features[camera_key]["shape"])
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
@@ -143,6 +143,6 @@ dataloader = torch.utils.data.DataLoader(
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
@@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3
|
||||
|
||||
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[pusht]"`
|
||||
pip install -e ".[pusht]"
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami
|
||||
|
||||
Using `pip`:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[dynamixel]"
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
Using `poetry`:
|
||||
@@ -55,6 +55,9 @@ Finally, connect both arms to your computer via USB. Note that the USB doesn't p
|
||||
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
||||
|
||||
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
@@ -828,10 +831,10 @@ It contains:
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
|
||||
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
|
||||
- or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`),
|
||||
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
|
||||
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
|
||||
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge`
|
||||
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
|
||||
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:
|
||||
|
||||
@@ -43,14 +43,19 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
6. Install LeRobot with stretch dependencies:
|
||||
6. When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[stretch]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
7. Install LeRobot with stretch dependencies:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[stretch]"
|
||||
```
|
||||
|
||||
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
|
||||
|
||||
7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
|
||||
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
|
||||
```bash
|
||||
stretch_system_check.py
|
||||
```
|
||||
@@ -97,6 +102,8 @@ This is equivalent to running `stretch_robot_home.py`
|
||||
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||
|
||||
Now try out teleoperation (see above documentation to learn about the gamepad controls):
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
|
||||
@@ -30,9 +30,14 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||
5. When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
@@ -43,6 +48,9 @@ Teleoperation consists in manually operating the leader arms to move the followe
|
||||
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
|
||||
|
||||
By running the following code, you can start your first **SAFE** teleoperation:
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
|
||||
@@ -1053,7 +1053,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
|
||||
@@ -240,7 +240,7 @@ def load_episodes_stats(local_dir: Path) -> dict:
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
return {ep_idx: stats for ep_idx in episodes}
|
||||
return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
|
||||
@@ -481,7 +481,7 @@ def convert_dataset(
|
||||
|
||||
# Tasks
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_path:
|
||||
|
||||
@@ -13,7 +13,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -86,3 +90,38 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
|
||||
first_type = type(env.envs[0]) # Get type of first env
|
||||
return all(type(e) is first_type for e in env.envs) # Fast type check
|
||||
|
||||
|
||||
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
||||
|
||||
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
|
||||
warnings.warn(
|
||||
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not are_all_envs_same_type(env):
|
||||
warnings.warn(
|
||||
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
observation["task"] = env.call("task_description")
|
||||
elif hasattr(env.envs[0], "task"):
|
||||
observation["task"] = env.call("task")
|
||||
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
|
||||
num_envs = observation[list(observation.keys())[0]].shape[0]
|
||||
observation["task"] = ["" for _ in range(num_envs)]
|
||||
return observation
|
||||
|
||||
1
lerobot/common/mocks/__init__.py
Normal file
1
lerobot/common/mocks/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Common mocks for robot devices and testing
|
||||
0
lerobot/common/mocks/cameras/__init__.py
Normal file
0
lerobot/common/mocks/cameras/__init__.py
Normal file
101
lerobot/common/mocks/cameras/mock_cv2.py
Normal file
101
lerobot/common/mocks/cameras/mock_cv2.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import cache
|
||||
|
||||
import numpy as np
|
||||
|
||||
CAP_V4L2 = 200
|
||||
CAP_DSHOW = 700
|
||||
CAP_AVFOUNDATION = 1200
|
||||
CAP_ANY = -1
|
||||
|
||||
CAP_PROP_FPS = 5
|
||||
CAP_PROP_FRAME_WIDTH = 3
|
||||
CAP_PROP_FRAME_HEIGHT = 4
|
||||
COLOR_RGB2BGR = 4
|
||||
COLOR_BGR2RGB = 4
|
||||
|
||||
ROTATE_90_COUNTERCLOCKWISE = 2
|
||||
ROTATE_90_CLOCKWISE = 0
|
||||
ROTATE_180 = 1
|
||||
|
||||
|
||||
@cache
|
||||
def _generate_image(width: int, height: int):
|
||||
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def cvtColor(color_image, color_conversion): # noqa: N802
|
||||
if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
|
||||
return color_image[:, :, [2, 1, 0]]
|
||||
else:
|
||||
raise NotImplementedError(color_conversion)
|
||||
|
||||
|
||||
def rotate(color_image, rotation):
|
||||
if rotation is None:
|
||||
return color_image
|
||||
elif rotation == ROTATE_90_CLOCKWISE:
|
||||
return np.rot90(color_image, k=1)
|
||||
elif rotation == ROTATE_180:
|
||||
return np.rot90(color_image, k=2)
|
||||
elif rotation == ROTATE_90_COUNTERCLOCKWISE:
|
||||
return np.rot90(color_image, k=3)
|
||||
else:
|
||||
raise NotImplementedError(rotation)
|
||||
|
||||
|
||||
class VideoCapture:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._mock_dict = {
|
||||
CAP_PROP_FPS: 30,
|
||||
CAP_PROP_FRAME_WIDTH: 640,
|
||||
CAP_PROP_FRAME_HEIGHT: 480,
|
||||
}
|
||||
self._is_opened = True
|
||||
|
||||
def isOpened(self): # noqa: N802
|
||||
return self._is_opened
|
||||
|
||||
def set(self, propId: int, value: float) -> bool: # noqa: N803
|
||||
if not self._is_opened:
|
||||
raise RuntimeError("Camera is not opened")
|
||||
self._mock_dict[propId] = value
|
||||
return True
|
||||
|
||||
def get(self, propId: int) -> float: # noqa: N803
|
||||
if not self._is_opened:
|
||||
raise RuntimeError("Camera is not opened")
|
||||
value = self._mock_dict[propId]
|
||||
if value == 0:
|
||||
if propId == CAP_PROP_FRAME_HEIGHT:
|
||||
value = 480
|
||||
elif propId == CAP_PROP_FRAME_WIDTH:
|
||||
value = 640
|
||||
return value
|
||||
|
||||
def read(self):
|
||||
if not self._is_opened:
|
||||
raise RuntimeError("Camera is not opened")
|
||||
h = self.get(CAP_PROP_FRAME_HEIGHT)
|
||||
w = self.get(CAP_PROP_FRAME_WIDTH)
|
||||
ret = True
|
||||
return ret, _generate_image(width=w, height=h)
|
||||
|
||||
def release(self):
|
||||
self._is_opened = False
|
||||
|
||||
def __del__(self):
|
||||
if self._is_opened:
|
||||
self.release()
|
||||
148
lerobot/common/mocks/cameras/mock_pyrealsense2.py
Normal file
148
lerobot/common/mocks/cameras/mock_pyrealsense2.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class stream(enum.Enum): # noqa: N801
|
||||
color = 0
|
||||
depth = 1
|
||||
|
||||
|
||||
class format(enum.Enum): # noqa: N801
|
||||
rgb8 = 0
|
||||
z16 = 1
|
||||
|
||||
|
||||
class config: # noqa: N801
|
||||
def enable_device(self, device_id: str):
|
||||
self.device_enabled = device_id
|
||||
|
||||
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
|
||||
self.stream_type = stream_type
|
||||
# Overwrite default values when possible
|
||||
self.width = 848 if width is None else width
|
||||
self.height = 480 if height is None else height
|
||||
self.color_format = format.rgb8 if color_format is None else color_format
|
||||
self.fps = 30 if fps is None else fps
|
||||
|
||||
|
||||
class RSColorProfile:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def fps(self):
|
||||
return self.config.fps
|
||||
|
||||
def width(self):
|
||||
return self.config.width
|
||||
|
||||
def height(self):
|
||||
return self.config.height
|
||||
|
||||
|
||||
class RSColorStream:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def as_video_stream_profile(self):
|
||||
return RSColorProfile(self.config)
|
||||
|
||||
|
||||
class RSProfile:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_stream(self, color_format):
|
||||
del color_format # unused
|
||||
return RSColorStream(self.config)
|
||||
|
||||
|
||||
class pipeline: # noqa: N801
|
||||
def __init__(self):
|
||||
self.started = False
|
||||
self.config = None
|
||||
|
||||
def start(self, config):
|
||||
self.started = True
|
||||
self.config = config
|
||||
return RSProfile(self.config)
|
||||
|
||||
def stop(self):
|
||||
if not self.started:
|
||||
raise RuntimeError("You need to start the camera before stop.")
|
||||
self.started = False
|
||||
self.config = None
|
||||
|
||||
def wait_for_frames(self, timeout_ms=50000):
|
||||
del timeout_ms # unused
|
||||
return RSFrames(self.config)
|
||||
|
||||
|
||||
class RSFrames:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_color_frame(self):
|
||||
return RSColorFrame(self.config)
|
||||
|
||||
def get_depth_frame(self):
|
||||
return RSDepthFrame(self.config)
|
||||
|
||||
|
||||
class RSColorFrame:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_data(self):
|
||||
data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8)
|
||||
# Create a difference between rgb and bgr
|
||||
data[:, :, 0] = 2
|
||||
return data
|
||||
|
||||
|
||||
class RSDepthFrame:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def get_data(self):
|
||||
return np.ones((self.config.height, self.config.width), dtype=np.uint16)
|
||||
|
||||
|
||||
class RSDevice:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_info(self, camera_info) -> str:
|
||||
del camera_info # unused
|
||||
# return fake serial number
|
||||
return "123456789"
|
||||
|
||||
|
||||
class context: # noqa: N801
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def query_devices(self):
|
||||
return [RSDevice()]
|
||||
|
||||
|
||||
class camera_info: # noqa: N801
|
||||
# fake name
|
||||
name = "Intel RealSense D435I"
|
||||
|
||||
def __init__(self, serial_number):
|
||||
del serial_number
|
||||
pass
|
||||
1
lerobot/common/mocks/motors/__init__.py
Normal file
1
lerobot/common/mocks/motors/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Mocks for motor modules
|
||||
107
lerobot/common/mocks/motors/mock_dynamixel_sdk.py
Normal file
107
lerobot/common/mocks/motors/mock_dynamixel_sdk.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
|
||||
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
||||
|
||||
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
|
||||
from the original classes and functions (e.g. return types might be None instead of boolean).
|
||||
"""
|
||||
|
||||
# from dynamixel_sdk import COMM_SUCCESS
|
||||
|
||||
DEFAULT_BAUDRATE = 9_600
|
||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||
# `convert_bytes_to_value`
|
||||
del bytes # unused
|
||||
return value
|
||||
|
||||
|
||||
def get_default_motor_values(motor_index):
|
||||
return {
|
||||
# Key (int) are from X_SERIES_CONTROL_TABLE
|
||||
7: motor_index, # ID
|
||||
8: DEFAULT_BAUDRATE, # Baud_rate
|
||||
10: 0, # Drive_Mode
|
||||
64: 0, # Torque_Enable
|
||||
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
|
||||
# For other joints, 2560 will be autocorrected to be in calibration range
|
||||
132: 2560, # Present_Position
|
||||
}
|
||||
|
||||
|
||||
class PortHandler:
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
# factory default baudrate
|
||||
self.baudrate = DEFAULT_BAUDRATE
|
||||
|
||||
def openPort(self): # noqa: N802
|
||||
return True
|
||||
|
||||
def closePort(self): # noqa: N802
|
||||
pass
|
||||
|
||||
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
|
||||
del timeout_ms # unused
|
||||
|
||||
def getBaudRate(self): # noqa: N802
|
||||
return self.baudrate
|
||||
|
||||
def setBaudRate(self, baudrate): # noqa: N802
|
||||
self.baudrate = baudrate
|
||||
|
||||
|
||||
class PacketHandler:
|
||||
def __init__(self, protocol_version):
|
||||
del protocol_version # unused
|
||||
# Use packet_handler.data to communicate across Read and Write
|
||||
self.data = {}
|
||||
|
||||
|
||||
class GroupSyncRead:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if motor_index not in self.packet_handler.data:
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
||||
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def getData(self, index, address, bytes): # noqa: N802
|
||||
return self.packet_handler.data[index][address]
|
||||
|
||||
|
||||
class GroupSyncWrite:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
self.address = address
|
||||
|
||||
def addParam(self, index, data): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if index not in self.packet_handler.data:
|
||||
self.packet_handler.data[index] = get_default_motor_values(index)
|
||||
self.changeParam(index, data)
|
||||
|
||||
def txPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def changeParam(self, index, data): # noqa: N802
|
||||
self.packet_handler.data[index][self.address] = data
|
||||
125
lerobot/common/mocks/motors/mock_scservo_sdk.py
Normal file
125
lerobot/common/mocks/motors/mock_scservo_sdk.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration
|
||||
and testing code logic that requires hardware and devices (e.g. robot arms, cameras)
|
||||
|
||||
Warning: These mocked versions are minimalist. They do not exactly mock every behaviors
|
||||
from the original classes and functions (e.g. return types might be None instead of boolean).
|
||||
"""
|
||||
|
||||
# from dynamixel_sdk import COMM_SUCCESS
|
||||
|
||||
DEFAULT_BAUDRATE = 1_000_000
|
||||
COMM_SUCCESS = 0 # tx or rx packet communication success
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
# TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform
|
||||
# `convert_bytes_to_value`
|
||||
del bytes # unused
|
||||
return value
|
||||
|
||||
|
||||
def get_default_motor_values(motor_index):
|
||||
return {
|
||||
# Key (int) are from SCS_SERIES_CONTROL_TABLE
|
||||
5: motor_index, # ID
|
||||
6: DEFAULT_BAUDRATE, # Baud_rate
|
||||
10: 0, # Drive_Mode
|
||||
21: 32, # P_Coefficient
|
||||
22: 32, # D_Coefficient
|
||||
23: 0, # I_Coefficient
|
||||
40: 0, # Torque_Enable
|
||||
41: 254, # Acceleration
|
||||
31: -2047, # Offset
|
||||
33: 0, # Mode
|
||||
55: 1, # Lock
|
||||
# Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144
|
||||
# For other joints, 2560 will be autocorrected to be in calibration range
|
||||
56: 2560, # Present_Position
|
||||
58: 0, # Present_Speed
|
||||
69: 0, # Present_Current
|
||||
85: 150, # Maximum_Acceleration
|
||||
}
|
||||
|
||||
|
||||
class PortHandler:
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
# factory default baudrate
|
||||
self.baudrate = DEFAULT_BAUDRATE
|
||||
self.ser = SerialMock()
|
||||
|
||||
def openPort(self): # noqa: N802
|
||||
return True
|
||||
|
||||
def closePort(self): # noqa: N802
|
||||
pass
|
||||
|
||||
def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802
|
||||
del timeout_ms # unused
|
||||
|
||||
def getBaudRate(self): # noqa: N802
|
||||
return self.baudrate
|
||||
|
||||
def setBaudRate(self, baudrate): # noqa: N802
|
||||
self.baudrate = baudrate
|
||||
|
||||
|
||||
class PacketHandler:
|
||||
def __init__(self, protocol_version):
|
||||
del protocol_version # unused
|
||||
# Use packet_handler.data to communicate across Read and Write
|
||||
self.data = {}
|
||||
|
||||
|
||||
class GroupSyncRead:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if motor_index not in self.packet_handler.data:
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
||||
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def getData(self, index, address, bytes): # noqa: N802
|
||||
return self.packet_handler.data[index][address]
|
||||
|
||||
|
||||
class GroupSyncWrite:
|
||||
def __init__(self, port_handler, packet_handler, address, bytes):
|
||||
self.packet_handler = packet_handler
|
||||
self.address = address
|
||||
|
||||
def addParam(self, index, data): # noqa: N802
|
||||
if index not in self.packet_handler.data:
|
||||
self.packet_handler.data[index] = get_default_motor_values(index)
|
||||
self.changeParam(index, data)
|
||||
|
||||
def txPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
def changeParam(self, index, data): # noqa: N802
|
||||
self.packet_handler.data[index][self.address] = data
|
||||
|
||||
|
||||
class SerialMock:
|
||||
def reset_output_buffer(self):
|
||||
pass
|
||||
|
||||
def reset_input_buffer(self):
|
||||
pass
|
||||
@@ -25,6 +25,7 @@ from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -54,6 +55,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "pi0fast":
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -69,6 +74,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[pi0]"
|
||||
pip install -e ".[pi0]"
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
|
||||
136
lerobot/common/policies/pi0fast/configuration_pi0fast.py
Normal file
136
lerobot/common/policies/pi0fast/configuration_pi0fast.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0fast")
|
||||
@dataclass
|
||||
class PI0FASTConfig(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 10
|
||||
n_action_steps: int = 5
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32 # 32
|
||||
max_action_dim: int = 32 # 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
||||
interpolate_like_pi: bool = False
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 1024
|
||||
|
||||
# Decoding
|
||||
max_decoding_steps: int = 256
|
||||
fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
|
||||
max_input_seq_len: int = 256 # 512
|
||||
|
||||
# Utils
|
||||
use_cache: bool = True
|
||||
|
||||
# Frozen parameters
|
||||
freeze_vision_encoder: bool = True
|
||||
freeze_lm_head: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
checkpoint_path: str = None
|
||||
|
||||
padding_side: str = "right"
|
||||
|
||||
precision: str = "bfloat16"
|
||||
grad_clip_norm: float = 1
|
||||
|
||||
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
|
||||
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
|
||||
relaxed_action_decoding: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
973
lerobot/common/policies/pi0fast/modeling_pi0fast.py
Normal file
973
lerobot/common/policies/pi0fast/modeling_pi0fast.py
Normal file
@@ -0,0 +1,973 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
|
||||
|
||||
[Paper](https://arxiv.org/abs/2501.09747)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of training the pi0+FAST neural network with from scratch:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=pi0fast \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from scipy.fft import idct
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
|
||||
from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
PRECISION = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0FASTPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = PI0FASTConfig
|
||||
name = "pi0fast"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0FASTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.model.generate_actions(batch)
|
||||
|
||||
actions = actions[:, : self.config.n_action_steps]
|
||||
|
||||
original_action_dim = self.config.action_feature.shape[
|
||||
0
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss_dict = self.model.forward(batch)
|
||||
return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
def block_causal_update_causal_mask(
|
||||
attention_mask,
|
||||
token_type_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
input_tensor=None,
|
||||
attn_implementation: str = "eager",
|
||||
dtype: torch.dtype = "float32",
|
||||
):
|
||||
"""
|
||||
Update the causal mask during training and generation. It can be customized to different attention masks.
|
||||
"""
|
||||
if attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
if input_tensor is None:
|
||||
input_tensor = attention_mask
|
||||
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
|
||||
if using_static_cache or isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
# Handle precomputed attention masks
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
return attention_mask
|
||||
|
||||
# Causal mask initialization
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
|
||||
# Standard causal masking (triu ensures tokens can only attend to past)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
|
||||
# Apply block causal mask
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.to(causal_mask.device).bool()
|
||||
cumsum = torch.cumsum(token_type_ids, dim=1)
|
||||
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
|
||||
# Combine causal_mask with block-wise attention mask
|
||||
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
|
||||
causal_mask = causal_mask[:, None, :, :]
|
||||
else:
|
||||
# Apply past cache position constraint
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
else:
|
||||
# Apply past cache position constraint
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# Apply padding mask
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
# self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=None,
|
||||
labels=None,
|
||||
self=None,
|
||||
**kwargs,
|
||||
):
|
||||
# create block causal attention
|
||||
if cache_position[0] > 0 and input_ids.shape[1] > 0:
|
||||
input_tensor = input_ids[:, -1:]
|
||||
new_positions = (
|
||||
torch.ones(
|
||||
(position_ids.shape[0], input_ids.shape[1]),
|
||||
dtype=position_ids.dtype,
|
||||
device=position_ids.device,
|
||||
).cumsum(-1)
|
||||
+ position_ids[:, -1:]
|
||||
)
|
||||
position_ids = torch.cat([position_ids, new_positions], dim=-1)
|
||||
else:
|
||||
input_tensor = inputs_embeds
|
||||
attention_mask = block_causal_update_causal_mask(
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
input_tensor=input_tensor,
|
||||
token_type_ids=token_type_ids,
|
||||
dtype=self.dtype,
|
||||
attn_implementation=self.config.text_config._attn_implementation,
|
||||
)
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
class PI0FAST(nn.Module):
|
||||
def __init__(self, config: PI0FASTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# TODO: move tokenizers in Policy
|
||||
fast_tokenizer_path = "physical-intelligence/fast"
|
||||
pi0_paligemma_path = "google/paligemma-3b-pt-224"
|
||||
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
|
||||
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
|
||||
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
|
||||
self.fast_skip_tokens = self.config.fast_skip_tokens
|
||||
self.max_input_seq_len = self.config.max_input_seq_len
|
||||
self.action_horizon = self.config.chunk_size
|
||||
self.action_dim = self.config.action_feature.shape[
|
||||
0
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
precision = config.precision
|
||||
torch_precision = PRECISION.get(precision, torch.float32)
|
||||
self.pad_token_id = (
|
||||
self.paligemma_tokenizer.pad_token_id
|
||||
if hasattr(self.paligemma_tokenizer, "pad_token_id")
|
||||
else self.paligemma_tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": precision,
|
||||
"vocab_size": 257152,
|
||||
"_attn_implementation": "eager",
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_pytorch_tanh",
|
||||
"torch_dtype": precision,
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
||||
|
||||
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
||||
prepare_inputs_for_generation, self=self.pi0_paligemma
|
||||
)
|
||||
# change important stuff in bf16
|
||||
params_to_change_dtype = [
|
||||
"language_model",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.pi0_paligemma.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch_precision)
|
||||
self.set_requires_grad()
|
||||
self.image_keys = self.config.image_features.keys()
|
||||
self.ignore_index = self.pi0_paligemma.config.ignore_index
|
||||
self.padding_side = self.config.padding_side
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.pi0_paligemma.vision_tower.eval()
|
||||
for params in self.pi0_paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
if self.config.freeze_lm_head:
|
||||
for name, params in self.pi0_paligemma.named_parameters():
|
||||
if "embed_tokens" in name: # lm heads and embedding layer are tied
|
||||
params.requires_grad = False
|
||||
|
||||
def embed_tokens(self, tokens: torch.Tensor):
|
||||
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Preprocess LeRobot batch into Pi0 inputs"""
|
||||
images = []
|
||||
img_masks = []
|
||||
present_img_keys = [key for key in self.image_keys if key in batch]
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
|
||||
# Preprocess image features present in the batch
|
||||
num_empty_cameras = 0
|
||||
for key in self.image_keys:
|
||||
if key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(
|
||||
img,
|
||||
*self.config.resize_imgs_with_padding,
|
||||
pad_value=0,
|
||||
interpolate_like_pi=self.config.interpolate_like_pi,
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
else:
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
continue
|
||||
img = torch.ones_like(img) * -1
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
num_empty_cameras += 1
|
||||
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
|
||||
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
|
||||
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
||||
|
||||
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
return out
|
||||
|
||||
def fast_tokenizer_wrapper(self, actions_norm):
|
||||
"""
|
||||
A wrapper for self.fast_tokenizer that ensures batch processing,
|
||||
conversion to PyTorch tensors, and returns a dictionary without padding.
|
||||
"""
|
||||
batch_tokens = self.fast_tokenizer(actions_norm)
|
||||
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
|
||||
|
||||
return fast_out
|
||||
|
||||
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
|
||||
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
|
||||
# Compute cumulative sum mask
|
||||
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
|
||||
# Suffix block (everything after prefix_len)
|
||||
suffix_mask = cumsum_mask > prefix_len
|
||||
token_type_ids = suffix_mask
|
||||
return token_type_ids
|
||||
|
||||
def create_input_tokens(self, state, lang_text, actions=None):
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
|
||||
discretized = torch.bucketize(state, bins) - 1
|
||||
discretized = discretized[:, :32]
|
||||
|
||||
prefix_texts = []
|
||||
state_text = []
|
||||
for txt, disc in zip(lang_text, discretized, strict=False):
|
||||
cleaned = txt.lower().strip().replace("_", " ")
|
||||
state_str = " ".join(str(val.item()) for val in disc)
|
||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
|
||||
state_text.append(f"State: {state_str};\n")
|
||||
|
||||
prefix_out = self.paligemma_tokenizer(
|
||||
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
|
||||
)
|
||||
prefix_ids = prefix_out["input_ids"].to(device)
|
||||
prefix_mask = prefix_out["attention_mask"].to(device)
|
||||
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
|
||||
|
||||
if actions is not None:
|
||||
actions_norm = self.normalize_actions(actions)
|
||||
actions_pad = F.pad(
|
||||
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
|
||||
)[:, :, : self.config.max_action_dim]
|
||||
fast_out = self.fast_tokenizer_wrapper(
|
||||
actions_pad.cpu(),
|
||||
)
|
||||
act_ids = fast_out["input_ids"]
|
||||
act_mask = fast_out["attention_mask"].to(device)
|
||||
|
||||
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
|
||||
# Replace action with 0 to pad tokens
|
||||
act_ids = torch.where(
|
||||
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
||||
self.pad_token_id,
|
||||
act_ids,
|
||||
)
|
||||
|
||||
eos_token = torch.tensor(
|
||||
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
|
||||
).expand(bsize, -1)
|
||||
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
|
||||
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
|
||||
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
|
||||
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
|
||||
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
|
||||
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
|
||||
act_mask = act_mask.to(device)
|
||||
else:
|
||||
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
|
||||
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
|
||||
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
|
||||
|
||||
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
|
||||
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
|
||||
|
||||
# Use tokenizer pad function
|
||||
padded_output = self.paligemma_tokenizer.pad(
|
||||
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
|
||||
)
|
||||
padded_mask = padded_output["attention_mask"]
|
||||
|
||||
# define tensor of padding lengths
|
||||
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
|
||||
|
||||
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
|
||||
|
||||
padded_output["padded_mask"] = padded_output.pop("attention_mask")
|
||||
padded_output["attention_mask"] = att_mask
|
||||
# loss is computed not on prefix, and not on padding
|
||||
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
|
||||
padded_output["token_type_ids"] = token_type_ids
|
||||
return padded_output
|
||||
|
||||
def shift_padding_side(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
ar_mask: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
loss_mask: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
token_type_ids: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> tuple[torch.Tensor]:
|
||||
if padding_side not in ["right", "left"]:
|
||||
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
|
||||
|
||||
new_tokens = torch.empty_like(tokens)
|
||||
new_ar_masks = torch.empty_like(ar_mask)
|
||||
new_padding_mask = torch.empty_like(padding_mask)
|
||||
new_loss_mask = torch.empty_like(loss_mask)
|
||||
new_targets = torch.empty_like(targets)
|
||||
new_token_type_ids = torch.empty_like(token_type_ids)
|
||||
batch_size = tokens.shape[0]
|
||||
for i in range(batch_size):
|
||||
padding_indices = torch.where(padding_mask[i] == 0)[0]
|
||||
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
|
||||
if padding_side == "left":
|
||||
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
|
||||
else:
|
||||
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
|
||||
new_tokens[i] = tokens[i].index_select(0, new_indices)
|
||||
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
|
||||
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
|
||||
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
|
||||
new_targets[i] = targets[i].index_select(0, new_indices)
|
||||
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
|
||||
|
||||
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]):
|
||||
device = batch[OBS_ROBOT].device
|
||||
# TODO: keep like this or move to the policy .forward
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(
|
||||
state=batch[OBS_ROBOT],
|
||||
lang_text=batch["task"],
|
||||
actions=batch[ACTION],
|
||||
)
|
||||
|
||||
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
padded_outs["input_ids"],
|
||||
padded_outs["padded_mask"],
|
||||
padded_outs["attention_mask"],
|
||||
padded_outs["loss_mask"],
|
||||
padded_outs["token_type_ids"],
|
||||
padding_side=self.padding_side,
|
||||
)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
||||
past_seen_tokens = 0
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
|
||||
pad_masks = block_causal_update_causal_mask(
|
||||
attention_mask=pad_masks,
|
||||
past_key_values=None,
|
||||
cache_position=cache_position,
|
||||
input_tensor=embs,
|
||||
token_type_ids=token_type_ids,
|
||||
dtype=self.pi0_paligemma.dtype,
|
||||
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
|
||||
)
|
||||
outputs = self.pi0_paligemma.forward(
|
||||
input_ids=None,
|
||||
token_type_ids=None,
|
||||
attention_mask=pad_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=embs,
|
||||
use_cache=False,
|
||||
labels=None,
|
||||
)
|
||||
|
||||
logits = outputs.logits
|
||||
|
||||
loss_fct = nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Shift left for next-step prediction
|
||||
logits = logits[:, :-1, :]
|
||||
targets = targets[:, 1:].to(device) # Shift targets
|
||||
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
|
||||
|
||||
# Compute per-token loss
|
||||
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
|
||||
|
||||
# Apply loss mask
|
||||
token_loss = token_loss * loss_mask.reshape(-1)
|
||||
|
||||
# Compute final loss
|
||||
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
|
||||
|
||||
# Return loss dictionary
|
||||
loss_dict = {"ce_loss": loss.item(), "loss": loss}
|
||||
return loss_dict
|
||||
|
||||
def decode_actions_with_fast(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
relaxed_decoding: bool = True,
|
||||
) -> np.array:
|
||||
"""
|
||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
||||
"""
|
||||
self.time_horizon = (
|
||||
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
|
||||
)
|
||||
self.action_dim = (
|
||||
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
|
||||
)
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
)
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
||||
if relaxed_decoding:
|
||||
# Expected sequence length
|
||||
expected_seq_len = self.time_horizon * self.action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
# Apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
|
||||
# Apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Extracts actions from predicted output tokens using the FAST model.
|
||||
|
||||
Args:
|
||||
tokens (torch.Tensor): The input tensor of tokenized outputs.
|
||||
action_horizon (int): The number of timesteps for actions.
|
||||
action_dim (int): The dimensionality of each action.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
|
||||
"""
|
||||
# Decode predicted output tokens
|
||||
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
|
||||
cleaned_tokens = [
|
||||
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
|
||||
for tokens_sequence in decoded_tokens
|
||||
]
|
||||
raw_action_tokens = [
|
||||
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
|
||||
for sample_tokens in cleaned_tokens
|
||||
] # something like this should be robust #looks good
|
||||
action_tokens = [
|
||||
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
|
||||
]
|
||||
# returns the tensor of decoded actions per sample in a list
|
||||
decoded_actions = [
|
||||
torch.tensor(
|
||||
self.decode_actions_with_fast(
|
||||
tok.tolist(),
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
relaxed_decoding=self.config.relaxed_action_decoding,
|
||||
),
|
||||
device=tokens.device,
|
||||
).squeeze(0)
|
||||
for tok in action_tokens
|
||||
]
|
||||
|
||||
return torch.stack(
|
||||
decoded_actions,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]):
|
||||
# TODO: keep like this or move to the policy .forward
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
|
||||
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
padded_outs["input_ids"],
|
||||
padded_outs["padded_mask"],
|
||||
padded_outs["attention_mask"],
|
||||
padded_outs["loss_mask"],
|
||||
padded_outs["token_type_ids"],
|
||||
padding_side="left",
|
||||
)
|
||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
||||
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
output_tokens = self.pi0_paligemma.generate(
|
||||
input_ids=None,
|
||||
attention_mask=pad_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=embs,
|
||||
use_cache=self.config.use_cache,
|
||||
max_new_tokens=self.config.max_decoding_steps,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
|
||||
return actions
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.pi0_paligemma.get_image_features(image)
|
||||
|
||||
def embed_inputs(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
pad_mask,
|
||||
ar_mask,
|
||||
loss_mask,
|
||||
token_type_ids,
|
||||
padding_side: str = "right",
|
||||
):
|
||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
||||
# images are a list of same size
|
||||
# vectorizing everything!
|
||||
device = images[0].device
|
||||
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
|
||||
all_images = torch.stack(images, dim=1).to(device)
|
||||
b, n, c, h, w = all_images.shape
|
||||
all_images = all_images.view(b * n, c, h, w)
|
||||
embedded = self.embed_image(all_images).to(device)
|
||||
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
|
||||
m = b_n // b # Compute the number of images per sample dynamically
|
||||
|
||||
# Reshape dynamically
|
||||
embedded = embedded.view(b, m, p, image_embedding_dim)
|
||||
tokens_embs = self.embed_tokens(tokens.to(device))
|
||||
|
||||
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
|
||||
num_img_emb = embedded.shape[2]
|
||||
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
|
||||
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
||||
|
||||
image_target_tokens = (
|
||||
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
|
||||
).reshape(b, -1)
|
||||
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
||||
|
||||
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
|
||||
|
||||
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
|
||||
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
|
||||
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
|
||||
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
|
||||
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
|
||||
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
|
||||
|
||||
# Shift pad tokens to the left (.generate()) or right (.train())
|
||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
|
||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
|
||||
)
|
||||
|
||||
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
|
||||
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
|
||||
if interpolate_like_pi:
|
||||
img = (img * 255.0).to(dtype=torch.uint8)
|
||||
img = img.permute(0, 2, 3, 1)
|
||||
original_device = img.device
|
||||
img = img.to(device="cpu").numpy()
|
||||
imgs = []
|
||||
for sub_img in img:
|
||||
sub_img = Image.fromarray(sub_img)
|
||||
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
|
||||
resized_img = torch.from_numpy(np.array(resized_img))
|
||||
imgs.append(resized_img)
|
||||
img = torch.stack(imgs, dim=0)
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
|
||||
else:
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
@@ -122,7 +122,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
# When the action queue is depleted, populate it again by querying the policy.
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
||||
|
||||
# Remove the time dimensions as it is not handled yet.
|
||||
for key in batch:
|
||||
|
||||
@@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
|
||||
connected to the computer.
|
||||
"""
|
||||
if mock:
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
@@ -100,7 +100,7 @@ def save_images_from_cameras(
|
||||
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
||||
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import lerobot.common.mocks.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -253,7 +253,7 @@ class IntelRealSenseCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import lerobot.common.mocks.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -287,7 +287,7 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
import lerobot.common.mocks.cameras.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
@@ -375,7 +375,7 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import lerobot.common.mocks.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ def _find_cameras(
|
||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||
) -> list[int | str]:
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import lerobot.common.mocks.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -269,7 +269,7 @@ class OpenCVCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import lerobot.common.mocks.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -286,7 +286,7 @@ class OpenCVCamera:
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import lerobot.common.mocks.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -398,7 +398,7 @@ class OpenCVCamera:
|
||||
# so we convert the image color from BGR to RGB.
|
||||
if requested_color_mode == "rgb":
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import lerobot.common.mocks.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig):
|
||||
fps: int | None = None
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
display_data: bool = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("record")
|
||||
@@ -82,7 +82,7 @@ class RecordControlConfig(ControlConfig):
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
display_data: bool = False
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
@@ -116,6 +116,11 @@ class ReplayControlConfig(ControlConfig):
|
||||
@dataclass
|
||||
class RemoteRobotConfig(ControlConfig):
|
||||
log_interval: int = 100
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp)
|
||||
viewer_ip: str | None = None
|
||||
viewer_port: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -24,7 +24,7 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
import rerun as rr
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
@@ -174,13 +174,13 @@ def warmup_record(
|
||||
events,
|
||||
enable_teleoperation,
|
||||
warmup_time_s,
|
||||
display_cameras,
|
||||
display_data,
|
||||
fps,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=warmup_time_s,
|
||||
display_cameras=display_cameras,
|
||||
display_data=display_data,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=enable_teleoperation,
|
||||
@@ -192,7 +192,7 @@ def record_episode(
|
||||
dataset,
|
||||
events,
|
||||
episode_time_s,
|
||||
display_cameras,
|
||||
display_data,
|
||||
policy,
|
||||
fps,
|
||||
single_task,
|
||||
@@ -200,7 +200,7 @@ def record_episode(
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
display_data=display_data,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
policy=policy,
|
||||
@@ -215,7 +215,7 @@ def control_loop(
|
||||
robot,
|
||||
control_time_s=None,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
events=None,
|
||||
policy: PreTrainedPolicy = None,
|
||||
@@ -264,11 +264,15 @@ def control_loop(
|
||||
frame = {**observation, **action, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
|
||||
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
|
||||
for k, v in action.items():
|
||||
for i, vv in enumerate(v):
|
||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
rr.log(key, rr.Image(observation[key].numpy()), static=True)
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
@@ -297,15 +301,11 @@ def reset_environment(robot, events, reset_time_s, fps):
|
||||
)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
def stop_recording(robot, listener, display_data):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless():
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
if display_cameras:
|
||||
cv2.destroyAllWindows()
|
||||
if not is_headless() and listener is not None:
|
||||
listener.stop()
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
|
||||
@@ -332,7 +332,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -356,7 +356,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -646,7 +646,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -691,7 +691,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -757,7 +757,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -793,7 +793,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import lerobot.common.mocks.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
|
||||
@@ -313,7 +313,7 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -337,7 +337,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -664,7 +664,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -702,7 +702,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -782,7 +782,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -818,7 +818,7 @@ class FeetechMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import lerobot.common.mocks.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
|
||||
@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091",
|
||||
port="/dev/tty.usbmodem58760429101",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
port="/dev/tty.usbmodem58760435821",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
|
||||
@@ -94,7 +94,7 @@ class MetricsTracker:
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
):
|
||||
self.__dict__.update({k: None for k in self.__keys__})
|
||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
|
||||
@@ -135,15 +135,19 @@ python lerobot/scripts/control_robot.py \
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
|
||||
import rerun as rr
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlConfig,
|
||||
ControlPipelineConfig,
|
||||
RecordControlConfig,
|
||||
RemoteRobotConfig,
|
||||
@@ -153,6 +157,7 @@ from lerobot.common.robot_devices.control_configs import (
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
log_control_info,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
@@ -232,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
|
||||
control_time_s=cfg.teleop_time_s,
|
||||
fps=cfg.fps,
|
||||
teleoperate=True,
|
||||
display_cameras=cfg.display_cameras,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -280,7 +285,7 @@ def record(
|
||||
# 3. place the cameras windows on screen
|
||||
enable_teleoperation = policy is None
|
||||
log_say("Warmup record", cfg.play_sounds)
|
||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
|
||||
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps)
|
||||
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
@@ -296,7 +301,7 @@ def record(
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
episode_time_s=cfg.episode_time_s,
|
||||
display_cameras=cfg.display_cameras,
|
||||
display_data=cfg.display_data,
|
||||
policy=policy,
|
||||
fps=cfg.fps,
|
||||
single_task=cfg.single_task,
|
||||
@@ -326,7 +331,7 @@ def record(
|
||||
break
|
||||
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, cfg.display_cameras)
|
||||
stop_recording(robot, listener, cfg.display_data)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||
@@ -363,6 +368,40 @@ def replay(
|
||||
log_control_info(robot, dt_s, fps=cfg.fps)
|
||||
|
||||
|
||||
def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop.
|
||||
|
||||
Args:
|
||||
control_config: Configuration determining data display and robot type.
|
||||
session_name: Rerun session name. Defaults to "lerobot_control_loop".
|
||||
|
||||
Raises:
|
||||
ValueError: If viewer IP is missing for non-remote configurations with display enabled.
|
||||
"""
|
||||
if (control_config.display_data and not is_headless()) or (
|
||||
control_config.display_data and isinstance(control_config, RemoteRobotConfig)
|
||||
):
|
||||
# Configure Rerun flush batch size default to 8KB if not set
|
||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||
|
||||
# Initialize Rerun based on configuration
|
||||
rr.init(session_name)
|
||||
if isinstance(control_config, RemoteRobotConfig):
|
||||
viewer_ip = control_config.viewer_ip
|
||||
viewer_port = control_config.viewer_port
|
||||
if not viewer_ip or not viewer_port:
|
||||
raise ValueError(
|
||||
"Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data."
|
||||
)
|
||||
logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}")
|
||||
rr.connect_tcp(f"{viewer_ip}:{viewer_port}")
|
||||
else:
|
||||
# Get memory limit for rerun viewer parameters
|
||||
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def control_robot(cfg: ControlPipelineConfig):
|
||||
init_logging()
|
||||
@@ -370,17 +409,22 @@ def control_robot(cfg: ControlPipelineConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
# TODO(Steven): Blueprint for fixed window size
|
||||
|
||||
if isinstance(cfg.control, CalibrateControlConfig):
|
||||
calibrate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, TeleoperateControlConfig):
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
|
||||
teleoperate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RecordControlConfig):
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
|
||||
record(robot, cfg.control)
|
||||
elif isinstance(cfg.control, ReplayControlConfig):
|
||||
replay(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RemoteRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
|
||||
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
|
||||
run_lekiwi(cfg.robot)
|
||||
|
||||
if robot.is_connected:
|
||||
|
||||
@@ -66,7 +66,7 @@ from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
@@ -124,7 +124,6 @@ def rollout(
|
||||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
@@ -145,6 +144,7 @@ def rollout(
|
||||
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
|
||||
leave=False,
|
||||
)
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done):
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
@@ -155,6 +155,10 @@ def rollout(
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
|
||||
53
lerobot/scripts/server/async_inference.proto
Normal file
53
lerobot/scripts/server/async_inference.proto
Normal file
@@ -0,0 +1,53 @@
|
||||
// fmt: off
|
||||
// flake8: noqa
|
||||
// !/usr/bin/env python
|
||||
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
syntax = "proto3";
|
||||
|
||||
package async_inference;
|
||||
|
||||
// AsyncInference: from Robot perspective
|
||||
// Robot send observations to & executes action received from a remote Policy server
|
||||
service AsyncInference {
|
||||
// Robot -> Policy to share observations with a remote inference server
|
||||
// Policy -> Robot to share actions predicted for given observations
|
||||
rpc SendObservations(stream Observation) returns (Empty);
|
||||
rpc StreamActions(Empty) returns (stream Action);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
TRANSFER_MIDDLE = 2;
|
||||
TRANSFER_END = 3;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Observation {
|
||||
// sent by Robot, to remote Policy
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Action {
|
||||
// sent by remote Policy, to Robot
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
46
lerobot/scripts/server/async_inference_pb2.py
Normal file
46
lerobot/scripts/server/async_inference_pb2.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: async_inference.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'async_inference.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"N\n\x06\x41\x63tion\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xd9\x01\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12\x42\n\rStreamActions\x12\x16.async_inference.Empty\x1a\x17.async_inference.Action0\x01\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=216
|
||||
_globals['_TRANSFERSTATE']._serialized_end=312
|
||||
_globals['_OBSERVATION']._serialized_start=42
|
||||
_globals['_OBSERVATION']._serialized_end=125
|
||||
_globals['_ACTION']._serialized_start=127
|
||||
_globals['_ACTION']._serialized_end=205
|
||||
_globals['_EMPTY']._serialized_start=207
|
||||
_globals['_EMPTY']._serialized_end=214
|
||||
_globals['_ASYNCINFERENCE']._serialized_start=315
|
||||
_globals['_ASYNCINFERENCE']._serialized_end=532
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
193
lerobot/scripts/server/async_inference_pb2_grpc.py
Normal file
193
lerobot/scripts/server/async_inference_pb2_grpc.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# fmt: off
|
||||
# flake8: noqa
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import async_inference_pb2 as async__inference__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.71.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class AsyncInferenceStub:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendObservations = channel.stream_unary(
|
||||
'/async_inference.AsyncInference/SendObservations',
|
||||
request_serializer=async__inference__pb2.Observation.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.StreamActions = channel.unary_stream(
|
||||
'/async_inference.AsyncInference/StreamActions',
|
||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Action.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/async_inference.AsyncInference/Ready',
|
||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class AsyncInferenceServicer:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
def SendObservations(self, request_iterator, context):
|
||||
"""Robot -> Policy to share observations with a remote inference server
|
||||
Policy -> Robot to share actions predicted for given observations
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def StreamActions(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AsyncInferenceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendObservations': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendObservations,
|
||||
request_deserializer=async__inference__pb2.Observation.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'StreamActions': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamActions,
|
||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
||||
response_serializer=async__inference__pb2.Action.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'async_inference.AsyncInference', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AsyncInference:
|
||||
"""AsyncInference: from Robot perspective
|
||||
Robot send observations to & executes action received from a remote Policy server
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendObservations(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/async_inference.AsyncInference/SendObservations',
|
||||
async__inference__pb2.Observation.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def StreamActions(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/StreamActions',
|
||||
async__inference__pb2.Empty.SerializeToString,
|
||||
async__inference__pb2.Action.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/async_inference.AsyncInference/Ready',
|
||||
async__inference__pb2.Empty.SerializeToString,
|
||||
async__inference__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
199
lerobot/scripts/server/policy_server.py
Normal file
199
lerobot/scripts/server/policy_server.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import itertools
|
||||
import pickle # nosec
|
||||
import time
|
||||
from concurrent import futures
|
||||
from queue import Queue
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import async_inference_pb2 # type: ignore
|
||||
import async_inference_pb2_grpc # type: ignore
|
||||
import grpc
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.scripts.server.robot_client import TimedAction, TimedObservation, environment_dt
|
||||
|
||||
inference_latency = 1 / 3
|
||||
idle_wait = 0.1
|
||||
|
||||
|
||||
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
def __init__(self):
|
||||
# TODO: Add device specification for policy inference at init
|
||||
self.device = "mps"
|
||||
start = time.time()
|
||||
self.policy = ACTPolicy.from_pretrained("fracapuano/act_so100_test")
|
||||
self.policy.to(self.device)
|
||||
end = time.time()
|
||||
print(f"Time taken to put policy on {self.device}: {end - start} seconds")
|
||||
|
||||
# Initialize dataset action generator
|
||||
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
|
||||
|
||||
self._setup_server()
|
||||
|
||||
self.actions_per_chunk = 20
|
||||
self.actions_overlap = 10
|
||||
|
||||
def _setup_server(self) -> None:
|
||||
"""Flushes server state when new client connects."""
|
||||
# only running inference on the latest observation received by the server
|
||||
self.observation_queue = Queue(maxsize=1)
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
self._setup_server()
|
||||
print("Client connected and ready")
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||
"""Receive observations from the robot client"""
|
||||
# client_id = context.peer()
|
||||
# print(f"Receiving observations from {client_id}")
|
||||
|
||||
for observation in request_iterator:
|
||||
timed_observation = pickle.loads(observation.data) # nosec
|
||||
|
||||
# If queue is full, get the old observation to make room
|
||||
if self.observation_queue.full():
|
||||
# pops from queue
|
||||
_ = self.observation_queue.get_nowait()
|
||||
|
||||
# Now put the new observation (never blocks as queue is non-full here)
|
||||
self.observation_queue.put(timed_observation)
|
||||
print("Received observation no: ", timed_observation.get_timestep())
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def StreamActions(self, request, context): # noqa: N802
|
||||
"""Stream actions to the robot client"""
|
||||
# client_id = context.peer()
|
||||
# print(f"Client {client_id} connected for action streaming")
|
||||
|
||||
# Generate action based on the most recent observation and its timestep
|
||||
obs = self.observation_queue.get()
|
||||
print("Running inference for timestep: ", obs.get_timestep())
|
||||
|
||||
if obs:
|
||||
yield self._predict_action_chunk(obs)
|
||||
|
||||
else:
|
||||
print("No observation in queue yet!")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
|
||||
"""Turn a chunk of actions into a list of TimedAction instances,
|
||||
with the first action corresponding to t_0 and the rest corresponding to
|
||||
t_0 + i*environment_dt for i in range(len(action_chunk))
|
||||
"""
|
||||
return [
|
||||
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
|
||||
"""Predict an action based on the observation"""
|
||||
self.policy.eval()
|
||||
|
||||
observation = {}
|
||||
for k, v in observation_t.get_observation().items():
|
||||
if "image" in k:
|
||||
observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device)
|
||||
else:
|
||||
observation[k] = v.unsqueeze(0).to(self.device)
|
||||
|
||||
# Remove batch dimension
|
||||
action_tensor = self.policy.select_action(observation).squeeze(0)
|
||||
|
||||
if action_tensor.dim() == 1:
|
||||
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
|
||||
action_tensor = action_tensor.cpu().repeat(self.actions_per_chunk, 1)
|
||||
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
)
|
||||
|
||||
action_bytes = pickle.dumps(action_chunk) # nosec
|
||||
# Create and return the Action message
|
||||
action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes)
|
||||
|
||||
time.sleep(inference_latency) # slow action generation, emulates inference time (ACT is very fast)
|
||||
|
||||
return action
|
||||
|
||||
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
|
||||
"""Stream chunks of actions from a prerecorded dataset.
|
||||
|
||||
Returns:
|
||||
Generator that yields chunks of actions from the dataset
|
||||
"""
|
||||
dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")
|
||||
|
||||
# 1. Select the action column only, where you will find tensors with 6 elements
|
||||
actions = dataset["action"]
|
||||
action_indices = torch.arange(len(actions))
|
||||
|
||||
# 2. Chunk the iterable of tensors into chunks with 10 elements each
|
||||
# sending only first element for debugging
|
||||
indices_chunks = action_indices.unfold(
|
||||
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
|
||||
)
|
||||
|
||||
for idx_chunk in indices_chunks:
|
||||
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
|
||||
|
||||
def _read_action_chunk(self, observation: Optional[TimedObservation] = None):
|
||||
"""Dummy function for predicting action chunk given observation.
|
||||
|
||||
Instead of computing actions on-the-fly, this method streams
|
||||
actions from a prerecorded dataset.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
if not observation:
|
||||
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
|
||||
transfer_state = 0
|
||||
else:
|
||||
transfer_state = observation.transfer_state
|
||||
|
||||
# Get chunk of actions from the generator
|
||||
actions_chunk = next(self.action_generator)
|
||||
|
||||
# Return a list of TimedActions, with timestamps starting from the observation timestamp
|
||||
action_data = self._time_action_chunk(
|
||||
observation.get_timestamp(), actions_chunk, observation.get_timestep()
|
||||
)
|
||||
action_bytes = pickle.dumps(action_data) # nosec
|
||||
|
||||
# Create and return the Action message
|
||||
action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes)
|
||||
|
||||
time.sleep(inference_latency) # slow action generation, emulates inference time
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(PolicyServer(), server)
|
||||
server.add_insecure_port("[::]:50051")
|
||||
server.start()
|
||||
print("PolicyServer started on port 50051")
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(86400) # Sleep for a day, or until interrupted
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
print("Server stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
357
lerobot/scripts/server/robot_client.py
Normal file
357
lerobot/scripts/server/robot_client.py
Normal file
@@ -0,0 +1,357 @@
|
||||
import pickle # nosec
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty, Queue
|
||||
from typing import Any, Optional
|
||||
|
||||
import async_inference_pb2 # type: ignore
|
||||
import async_inference_pb2_grpc # type: ignore
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
|
||||
environment_dt = 1 / 30
|
||||
idle_wait = 0.1
|
||||
|
||||
|
||||
class TimedData:
|
||||
def __init__(self, timestamp: float, data: Any, timestep: int):
|
||||
"""Initialize a TimedData object.
|
||||
|
||||
Args:
|
||||
timestamp: Unix timestamp relative to data's creation.
|
||||
data: The actual data to wrap a timestamp around.
|
||||
"""
|
||||
self.timestamp = timestamp
|
||||
self.data = data
|
||||
self.timestep = timestep
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
|
||||
def get_timestamp(self):
|
||||
return self.timestamp
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
class TimedAction(TimedData):
|
||||
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
|
||||
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
|
||||
|
||||
def get_action(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class TimedObservation(TimedData):
|
||||
def __init__(
|
||||
self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
|
||||
):
|
||||
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
|
||||
self.transfer_state = transfer_state
|
||||
|
||||
def get_observation(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class RobotClient:
|
||||
def __init__(
|
||||
self,
|
||||
# cfg: RobotConfig,
|
||||
server_address="localhost:50051",
|
||||
use_robot=True,
|
||||
):
|
||||
self.channel = grpc.insecure_channel(server_address)
|
||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
|
||||
self.running = False
|
||||
self.first_observation_sent = False
|
||||
self.latest_action = 0
|
||||
self.action_chunk_size = 20
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.start_barrier = threading.Barrier(3)
|
||||
|
||||
# Create a lock for robot access
|
||||
self.robot_lock = threading.Lock()
|
||||
|
||||
self.use_robot = use_robot
|
||||
if self.use_robot:
|
||||
self.robot = make_robot("so100")
|
||||
self.robot.connect()
|
||||
|
||||
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
||||
print("Robot connected")
|
||||
|
||||
self.robot_reading = True
|
||||
|
||||
def timestamps(self):
|
||||
"""Get the timestamps of the actions in the queue"""
|
||||
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
self.stub.Ready(async_inference_pb2.Empty())
|
||||
print("Connected to policy server")
|
||||
|
||||
self.running = True
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
print(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.running = False
|
||||
if self.use_robot and hasattr(self, "robot"):
|
||||
self.robot.disconnect()
|
||||
self.channel.close()
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
obs: TimedObservation,
|
||||
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
print("Client not running")
|
||||
return False
|
||||
|
||||
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
|
||||
|
||||
observation_bytes = pickle.dumps(obs)
|
||||
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
|
||||
|
||||
try:
|
||||
_ = self.stub.SendObservations(iter([observation]))
|
||||
if transfer_state == async_inference_pb2.TRANSFER_BEGIN:
|
||||
self.first_observation_sent = True
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
print(f"Error sending observation: {e}")
|
||||
return False
|
||||
|
||||
def _validate_action(self, action: TimedAction):
|
||||
"""Received actions are keps only when they have been produced for now or later, never before"""
|
||||
return not action.get_timestamp() < self.latest_action
|
||||
|
||||
def _validate_action_chunk(self, actions: list[TimedAction]):
|
||||
assert len(actions) == self.action_chunk_size, (
|
||||
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
|
||||
)
|
||||
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
|
||||
|
||||
return True
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
print("Queue size: ", self.action_queue.qsize())
|
||||
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
|
||||
|
||||
def _clear_queue(self):
|
||||
"""Clear the existing queue"""
|
||||
while not self.action_queue.empty():
|
||||
try:
|
||||
self.action_queue.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def _fill_action_queue(self, actions: list[TimedAction]):
|
||||
"""Fill the action queue with incoming valid actions"""
|
||||
for action in actions:
|
||||
if self._validate_action(action):
|
||||
self.action_queue.put(action)
|
||||
|
||||
def _update_action_queue(self, actions: list[TimedAction]):
|
||||
"""Aggregate incoming actions into the action queue.
|
||||
Raises NotImplementedError as this is not implemented yet.
|
||||
|
||||
Args:
|
||||
actions: List of TimedAction instances to queue
|
||||
"""
|
||||
# TODO: Implement this
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
|
||||
"""Clear the existing queue and fill it with new actions.
|
||||
This is a higher-level function that combines clearing and filling operations.
|
||||
|
||||
Args:
|
||||
actions: List of TimedAction instances to queue
|
||||
"""
|
||||
print("*** Current latest action: ", self.latest_action, "***")
|
||||
print("\t**** Current queue content ****: ")
|
||||
self._inspect_action_queue()
|
||||
|
||||
print("\t*** Incoming actions ****: ")
|
||||
print([a.get_timestep() for a in actions])
|
||||
|
||||
self._clear_queue()
|
||||
self._fill_action_queue(actions)
|
||||
|
||||
print("\t*** Queue after clearing and filling ****: ")
|
||||
self._inspect_action_queue()
|
||||
|
||||
def receive_actions(self):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
print("Action receiving thread starting")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
|
||||
# Deserialize bytes back into list[TimedAction]
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
|
||||
# strategy for queue composition is specified in the method
|
||||
self._clear_and_fill_action_queue(timed_actions)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
print(f"Error receiving actions: {e}")
|
||||
time.sleep(idle_wait) # Avoid tight loop on error
|
||||
|
||||
def _get_next_action(self) -> Optional[TimedAction]:
|
||||
"""Get the next action from the queue"""
|
||||
try:
|
||||
action = self.action_queue.get_nowait()
|
||||
return action
|
||||
|
||||
except Empty:
|
||||
return None
|
||||
|
||||
def execute_actions(self):
|
||||
"""Continuously execute actions from the queue"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
print("Action execution thread starting")
|
||||
|
||||
while self.running:
|
||||
# Get the next action from the queue
|
||||
time.sleep(environment_dt)
|
||||
timed_action = self._get_next_action()
|
||||
|
||||
if timed_action is not None:
|
||||
# self.latest_action = timed_action.get_timestep()
|
||||
self.latest_action = timed_action.get_timestamp()
|
||||
|
||||
# Convert action to tensor and send to robot
|
||||
if self.use_robot:
|
||||
# Acquire lock before accessing the robot
|
||||
if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
|
||||
try:
|
||||
self.robot.send_action(timed_action.get_action())
|
||||
finally:
|
||||
# Always release the lock in a finally block to ensure it's released
|
||||
self.robot_lock.release()
|
||||
else:
|
||||
print("Could not acquire robot lock for action execution, retrying next cycle")
|
||||
|
||||
else:
|
||||
# No action available, wait and retry fetching from queue
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def stream_observations(self, get_observation_fn):
|
||||
"""Continuously stream observations to the server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
print("Observation streaming thread starting")
|
||||
|
||||
first_observation = True
|
||||
while self.running:
|
||||
try:
|
||||
# Get serialized observation bytes from the function
|
||||
time.sleep(environment_dt)
|
||||
observation = get_observation_fn()
|
||||
|
||||
# Skip if observation is None (couldn't acquire lock)
|
||||
if observation is None:
|
||||
continue
|
||||
|
||||
# Set appropriate transfer state
|
||||
if first_observation:
|
||||
state = async_inference_pb2.TRANSFER_BEGIN
|
||||
first_observation = False
|
||||
else:
|
||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||
|
||||
self.send_observation(observation, state)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in observation sender: {e}")
|
||||
time.sleep(idle_wait)
|
||||
|
||||
|
||||
def async_client():
|
||||
# Example of how to use the RobotClient
|
||||
client = RobotClient()
|
||||
|
||||
if client.start():
|
||||
# Function to generate mock observations
|
||||
def get_observation():
|
||||
# Create a counter attribute if it doesn't exist
|
||||
if not hasattr(get_observation, "counter"):
|
||||
get_observation.counter = 0
|
||||
|
||||
# Acquire lock before accessing the robot
|
||||
observation_content = None
|
||||
if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
|
||||
try:
|
||||
observation_content = client.robot.capture_observation()
|
||||
finally:
|
||||
# Always release the lock in a finally block to ensure it's released
|
||||
client.robot_lock.release()
|
||||
else:
|
||||
print("Could not acquire robot lock for observation capture, skipping this cycle")
|
||||
return None # Return None to indicate no observation was captured
|
||||
|
||||
observation = TimedObservation(
|
||||
timestamp=time.time(), observation=observation_content, timestep=get_observation.counter
|
||||
)
|
||||
|
||||
# Increment counter for next call
|
||||
get_observation.counter += 1
|
||||
|
||||
return observation
|
||||
|
||||
print("Starting all threads...")
|
||||
|
||||
# Create and start observation sender thread
|
||||
obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,))
|
||||
obs_thread.daemon = True
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions)
|
||||
action_receiver_thread.daemon = True
|
||||
|
||||
# Create action execution thread
|
||||
action_execution_thread = threading.Thread(target=client.execute_actions)
|
||||
action_execution_thread.daemon = True
|
||||
|
||||
# Start all threads
|
||||
obs_thread.start()
|
||||
action_receiver_thread.start()
|
||||
action_execution_thread.start()
|
||||
|
||||
try:
|
||||
# Main thread just keeps everything alive
|
||||
while client.running:
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
print("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
async_client()
|
||||
@@ -133,7 +133,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
|
||||
@@ -60,16 +60,16 @@ dependencies = [
|
||||
"jsonlines>=4.0.0",
|
||||
"numba>=0.59.0",
|
||||
"omegaconf>=2.3.0",
|
||||
"opencv-python>=4.9.0",
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"packaging>=24.2",
|
||||
"av>=12.0.5,<13.0.0",
|
||||
"av>=12.0.5",
|
||||
"pymunk>=6.6.0",
|
||||
"pynput>=1.7.7",
|
||||
"pyzmq>=26.2.1",
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0",
|
||||
"wandb>=0.16.3",
|
||||
"zarr>=2.17.0",
|
||||
|
||||
@@ -172,8 +172,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
push_to_hub=False,
|
||||
# TODO(rcadene, aliberts): test video=True
|
||||
video=False,
|
||||
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
@@ -226,7 +225,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
@@ -273,7 +272,7 @@ def test_resume_record(tmp_path, request, robot_type, mock):
|
||||
episode_time_s=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_episodes=1,
|
||||
)
|
||||
@@ -330,7 +329,7 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
@@ -380,7 +379,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
@@ -433,7 +432,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user