forked from tangger/lerobot
Compare commits
174 Commits
user/azoui
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dcd850feab | ||
|
|
1ce368503d | ||
|
|
fb075a709d | ||
|
|
3424644ecd | ||
|
|
c37936f2c9 | ||
|
|
c5382a450c | ||
|
|
2f7339b410 | ||
|
|
9e5f254db0 | ||
|
|
8122721f6d | ||
|
|
5c352ae558 | ||
|
|
9386892f8e | ||
|
|
267a837a2c | ||
|
|
28b595c651 | ||
|
|
9fd4c21d4d | ||
|
|
02e1ed0bfb | ||
|
|
e18274bc9a | ||
|
|
68c271ad25 | ||
|
|
a3ada81816 | ||
|
|
203315d378 | ||
|
|
78c640b6d8 | ||
|
|
d5a87f67cf | ||
|
|
8bcf41761d | ||
|
|
1efaf02df9 | ||
|
|
cf58890bb0 | ||
|
|
7c2c67fc3c | ||
|
|
70130b9841 | ||
|
|
6167886472 | ||
|
|
f9fb9d4594 | ||
|
|
d86d29fe21 | ||
|
|
f83d215e7a | ||
|
|
7361a11a4d | ||
|
|
0cce2fe0fa | ||
|
|
88d26ae976 | ||
|
|
3a2308d86f | ||
|
|
fdd04efdb7 | ||
|
|
ff18be18ad | ||
|
|
427720426b | ||
|
|
66693965c0 | ||
|
|
334cf8143e | ||
|
|
5b49601072 | ||
|
|
0185a0b6fd | ||
|
|
70d418935d | ||
|
|
eb44a06a9b | ||
|
|
8eb3c1510c | ||
|
|
4d5ecb082e | ||
|
|
6e687e2910 | ||
|
|
eb710647bf | ||
|
|
176557d770 | ||
|
|
3beab33fac | ||
|
|
c0ba4b4954 | ||
|
|
8fb373aeb2 | ||
|
|
5a0ee06651 | ||
|
|
05a237ce10 | ||
|
|
88cc2b8fc8 | ||
|
|
b69132c79d | ||
|
|
db897a1619 | ||
|
|
0b5b62c8fb | ||
|
|
056f79d358 | ||
|
|
114ec644d0 | ||
|
|
26ee8b6ae5 | ||
|
|
38e8864284 | ||
|
|
80d566eb56 | ||
|
|
bb5a95889f | ||
|
|
0ea27704f6 | ||
|
|
2abbd60a0d | ||
|
|
1c8daf11fd | ||
|
|
cdcf346061 | ||
|
|
42f95e827d | ||
|
|
618ed00d45 | ||
|
|
50d8db481e | ||
|
|
e4a5971ffd | ||
|
|
36f9ccd851 | ||
|
|
787aee0e60 | ||
|
|
0341a38fdd | ||
|
|
ffbed4a141 | ||
|
|
03fe0f054b | ||
|
|
fd74c194b6 | ||
|
|
0959694bab | ||
|
|
7b01e16439 | ||
|
|
66816fd871 | ||
|
|
599326508f | ||
|
|
2f04d0d2b9 | ||
|
|
e002c5ec56 | ||
|
|
3dfb37e976 | ||
|
|
b6a2200983 | ||
|
|
85fe8a3f4e | ||
|
|
bb69cb3c8c | ||
|
|
ae51c19b3c | ||
|
|
9ea79f8a76 | ||
|
|
1d4ec50a58 | ||
|
|
4c73891575 | ||
|
|
d3b84ecd6f | ||
|
|
e1d55c7a44 | ||
|
|
85242cac67 | ||
|
|
0d88a5ee09 | ||
|
|
62e237bdee | ||
|
|
c85f88fb62 | ||
|
|
a90f4872f2 | ||
|
|
a16ea283f5 | ||
|
|
8209a6dfb7 | ||
|
|
b5fbeb7401 | ||
|
|
2ac25b02e2 | ||
|
|
39fe4b1301 | ||
|
|
140e30e386 | ||
|
|
ddcc0415e4 | ||
|
|
5195f40fd3 | ||
|
|
98c6557869 | ||
|
|
ee820859d3 | ||
|
|
5d6879d93a | ||
|
|
fae47d58d3 | ||
|
|
3a07301365 | ||
|
|
f1af97dc9c | ||
|
|
f2266101df | ||
|
|
9784d8a47f | ||
|
|
af769abd8d | ||
|
|
12c13e320e | ||
|
|
273fa2e6e1 | ||
|
|
d143043037 | ||
|
|
ca45c34ad5 | ||
|
|
b1679050de | ||
|
|
d2c41b35db | ||
|
|
bc7b6d3daf | ||
|
|
2516101cba | ||
|
|
aebea08a99 | ||
|
|
03616db82c | ||
|
|
93c4fc198f | ||
|
|
8cd44ae163 | ||
|
|
2ae657f568 | ||
|
|
508f5d1407 | ||
|
|
c8b1132846 | ||
|
|
ef777993cd | ||
|
|
760d60ad4b | ||
|
|
875c0271b7 | ||
|
|
57344bfde5 | ||
|
|
46827fb002 | ||
|
|
2fd78879f6 | ||
|
|
e8449e9630 | ||
|
|
a0e2be8b92 | ||
|
|
181727c0fe | ||
|
|
d1d6ffd23c | ||
|
|
e5801f467f | ||
|
|
c6ca9523de | ||
|
|
642e3a3274 | ||
|
|
146148c48c | ||
|
|
8f15835daa | ||
|
|
022bd65125 | ||
|
|
63d8c96514 | ||
|
|
4624a836e5 | ||
|
|
ad7eea132d | ||
|
|
22a1899ff4 | ||
|
|
17a3a31b5f | ||
|
|
1a8b99e360 | ||
|
|
6db2154f28 | ||
|
|
be3adda95f | ||
|
|
9d48d236c1 | ||
|
|
b57d6a7776 | ||
|
|
d1f76cba8e | ||
|
|
d78cef1fee | ||
|
|
30a808c0ae | ||
|
|
4a7f85a6ec | ||
|
|
b43ece8934 | ||
|
|
c10c5a0e64 | ||
|
|
a8db91c40e | ||
|
|
0f5f7ac780 | ||
|
|
768e36660d | ||
|
|
790d6740ba | ||
|
|
5322417c03 | ||
|
|
4041f57943 | ||
|
|
2c86fea78a | ||
|
|
437fc29e12 | ||
|
|
aee86b4b18 | ||
|
|
1c873df5c0 | ||
|
|
145fe4cd17 | ||
|
|
e004247ed4 |
@@ -36,8 +36,8 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.30.2
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.31.1
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
@@ -48,7 +48,7 @@ repos:
|
||||
- id: pyupgrade
|
||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.10
|
||||
rev: v0.11.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -57,12 +57,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.0
|
||||
rev: v8.24.3
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.4.1
|
||||
rev: v1.5.2
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
|
||||
17
README.md
17
README.md
@@ -98,18 +98,25 @@ conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, if you don't have `ffmpeg` in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
conda install ffmpeg
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
|
||||
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
|
||||
> ```bash
|
||||
> conda install ffmpeg=7.1.1 -c conda-forge
|
||||
> ```
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install --no-binary=av -e .
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
||||
`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||
@@ -118,7 +125,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[aloha, pusht]"
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
|
||||
@@ -17,12 +17,21 @@
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import rerun as rr
|
||||
|
||||
# see https://rerun.io/docs/howto/visualization/limit-ram
|
||||
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
|
||||
|
||||
|
||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
|
||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
|
||||
rr.init("lerobot_capture_camera_feed")
|
||||
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
|
||||
|
||||
now = dt.datetime.now()
|
||||
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
||||
if not capture_dir.exists():
|
||||
@@ -39,24 +48,21 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
||||
|
||||
frame_index = 0
|
||||
while True:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < duration:
|
||||
ret, frame = cap.read()
|
||||
|
||||
if not ret:
|
||||
print("Error: Could not read frame.")
|
||||
break
|
||||
|
||||
cv2.imshow("Video Stream", frame)
|
||||
rr.log("video/stream", rr.Image(frame.numpy()), static=True)
|
||||
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
||||
frame_index += 1
|
||||
|
||||
# Break the loop on 'q' key press
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
# Release the capture and destroy all windows
|
||||
# Release the capture
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -86,5 +92,11 @@ if __name__ == "__main__":
|
||||
default=720,
|
||||
help="Height of the captured images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--duration",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Duration in seconds for which the video stream should be captured.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
display_and_save_video_stream(**vars(args))
|
||||
|
||||
@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher portaudio19-dev libgeos-dev \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install ffmpeg build dependencies. See:
|
||||
|
||||
@@ -57,9 +57,15 @@ conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
#### 5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:.
|
||||
@@ -491,6 +497,9 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
#### a. Teleop with displaying cameras
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so100 \
|
||||
|
||||
@@ -67,9 +67,15 @@ conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
#### 5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## C. Install LeRobot on laptop
|
||||
@@ -108,9 +114,15 @@ conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
#### 5. Install LeRobot with dependencies for the feetech motors:
|
||||
#### 5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
#### 6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:.
|
||||
@@ -393,6 +405,10 @@ python lerobot/scripts/control_robot.py \
|
||||
```
|
||||
|
||||
# F. Teleoperate
|
||||
|
||||
> [!TIP]
|
||||
> If you're using a Mac, you might need to give Terminal permission to access your keyboard. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
|
||||
|
||||
To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
@@ -408,6 +424,8 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.fps=30
|
||||
```
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port`
|
||||
|
||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||
| ---------- | ------------------ | ---------------------- |
|
||||
|
||||
@@ -31,9 +31,15 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the feetech motors:
|
||||
5. Install ffmpeg in your environment:
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[feetech]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
6. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## Configure the motors
|
||||
@@ -212,6 +218,9 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=moss \
|
||||
|
||||
@@ -122,7 +122,7 @@ print(dataset.features[camera_key]["shape"])
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
@@ -146,6 +146,6 @@ dataloader = torch.utils.data.DataLoader(
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
@@ -18,7 +18,7 @@ training outputs directory. In the latter case, you might want to run examples/3
|
||||
|
||||
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[pusht]"`
|
||||
pip install -e ".[pusht]"
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly
|
||||
|
||||
## The training script
|
||||
|
||||
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
|
||||
LeRobot offers a training script at [`lerobot/scripts/train.py`](../lerobot/scripts/train.py). At a high level it does the following:
|
||||
|
||||
- Initialize/load a configuration for the following steps using.
|
||||
- Instantiates a dataset.
|
||||
@@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
```
|
||||
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
|
||||
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
|
||||
|
||||
@@ -50,7 +50,7 @@ By default, every field takes its default value specified in the dataclass. If a
|
||||
|
||||
## Specifying values from the CLI
|
||||
|
||||
Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
Let's say that we want to train [Diffusion Policy](../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
@@ -60,10 +60,10 @@ python lerobot/scripts/train.py \
|
||||
|
||||
Let's break this down:
|
||||
- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
|
||||
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
|
||||
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
|
||||
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../lerobot/common/policies)
|
||||
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../lerobot/common/envs/configs.py)
|
||||
|
||||
Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
Let's see another example. Let's say you've been training [ACT](../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
@@ -74,7 +74,7 @@ python lerobot/scripts/train.py \
|
||||
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
|
||||
|
||||
We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
|
||||
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
Looking at the [`AlohaEnv`](../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
|
||||
@@ -33,7 +33,7 @@ First, install the additional dependencies required for robots built with dynami
|
||||
|
||||
Using `pip`:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[dynamixel]"
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
Using `poetry`:
|
||||
@@ -55,6 +55,9 @@ Finally, connect both arms to your computer via USB. Note that the USB doesn't p
|
||||
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
||||
|
||||
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
@@ -827,11 +830,6 @@ It contains:
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
|
||||
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
|
||||
- or, install [Homebrew](https://brew.sh) and run `brew install ffmpeg` (it should be compiled with `libsvtav1`),
|
||||
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
|
||||
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:
|
||||
|
||||
@@ -43,14 +43,19 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
6. Install LeRobot with stretch dependencies:
|
||||
6. When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[stretch]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
7. Install LeRobot with stretch dependencies:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[stretch]"
|
||||
```
|
||||
|
||||
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
|
||||
|
||||
7. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
|
||||
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
|
||||
```bash
|
||||
stretch_system_check.py
|
||||
```
|
||||
@@ -97,6 +102,8 @@ This is equivalent to running `stretch_robot_home.py`
|
||||
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||
|
||||
Now try out teleoperation (see above documentation to learn about the gamepad controls):
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=stretch \
|
||||
|
||||
@@ -30,9 +30,14 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||
5. When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
cd ~/lerobot && pip install --no-binary=av -e ".[dynamixel, intelrealsense]"
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
@@ -43,6 +48,9 @@ Teleoperation consists in manually operating the leader arms to move the followe
|
||||
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
|
||||
|
||||
By running the following code, you can start your first **SAFE** teleoperation:
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=aloha \
|
||||
|
||||
@@ -1058,7 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
|
||||
@@ -243,7 +243,7 @@ def load_episodes_stats(local_dir: Path) -> dict:
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
return {ep_idx: stats for ep_idx in episodes}
|
||||
return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
|
||||
@@ -508,7 +508,7 @@ def convert_dataset(
|
||||
|
||||
# Tasks
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_path:
|
||||
|
||||
@@ -171,7 +171,6 @@ class VideoRecordConfig:
|
||||
class WrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
delta_action: float | None = None
|
||||
joint_masking_action_space: list[bool] | None = None
|
||||
|
||||
|
||||
@@ -191,7 +190,6 @@ class EnvWrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
display_cameras: bool = False
|
||||
delta_action: float = 0.1
|
||||
use_relative_joint_positions: bool = True
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
@@ -203,8 +201,9 @@ class EnvWrapperConfig:
|
||||
joint_masking_action_space: Optional[Any] = None
|
||||
ee_action_space_params: Optional[EEActionSpaceConfig] = None
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float = 0.8
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
open_gripper_on_reset: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -95,38 +99,40 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
return policy_features
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
q_pos = observations["agent"]["qpos"]
|
||||
q_vel = observations["agent"]["qvel"]
|
||||
tcp_pos = observations["extra"]["tcp_pose"]
|
||||
img = observations["sensor_data"]["base_camera"]["rgb"]
|
||||
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
|
||||
first_type = type(env.envs[0]) # Get type of first env
|
||||
return all(type(e) is first_type for e in env.envs) # Fast type check
|
||||
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
if not (
|
||||
hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")
|
||||
):
|
||||
warnings.warn(
|
||||
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not are_all_envs_same_type(env):
|
||||
warnings.warn(
|
||||
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||
|
||||
return_observations["observation.image"] = img
|
||||
return_observations["observation.state"] = state
|
||||
return return_observations
|
||||
def add_envs_task(
|
||||
env: gym.vector.VectorEnv, observation: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
observation["task"] = env.call("task_description")
|
||||
elif hasattr(env.envs[0], "task"):
|
||||
observation["task"] = env.call("task")
|
||||
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
|
||||
num_envs = observation[list(observation.keys())[0]].shape[0]
|
||||
observation["task"] = ["" for _ in range(num_envs)]
|
||||
return observation
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -55,6 +56,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "pi0fast":
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -78,6 +83,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "hilserl_classifier":
|
||||
return ClassifierConfig(**kwargs)
|
||||
else:
|
||||
|
||||
@@ -24,7 +24,7 @@ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
pip install --no-binary=av -e ".[pi0]"
|
||||
pip install -e ".[pi0]"
|
||||
```
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
|
||||
136
lerobot/common/policies/pi0fast/configuration_pi0fast.py
Normal file
136
lerobot/common/policies/pi0fast/configuration_pi0fast.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0fast")
|
||||
@dataclass
|
||||
class PI0FASTConfig(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 10
|
||||
n_action_steps: int = 5
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32 # 32
|
||||
max_action_dim: int = 32 # 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (224, 224)
|
||||
interpolate_like_pi: bool = False
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 1024
|
||||
|
||||
# Decoding
|
||||
max_decoding_steps: int = 256
|
||||
fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
|
||||
max_input_seq_len: int = 256 # 512
|
||||
|
||||
# Utils
|
||||
use_cache: bool = True
|
||||
|
||||
# Frozen parameters
|
||||
freeze_vision_encoder: bool = True
|
||||
freeze_lm_head: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
checkpoint_path: str = None
|
||||
|
||||
padding_side: str = "right"
|
||||
|
||||
precision: str = "bfloat16"
|
||||
grad_clip_norm: float = 1
|
||||
|
||||
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
|
||||
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
|
||||
relaxed_action_decoding: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
973
lerobot/common/policies/pi0fast/modeling_pi0fast.py
Normal file
973
lerobot/common/policies/pi0fast/modeling_pi0fast.py
Normal file
@@ -0,0 +1,973 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models
|
||||
|
||||
[Paper](https://arxiv.org/abs/2501.09747)
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of training the pi0+FAST neural network with from scratch:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=pi0fast \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of using the pi0 pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from scipy.fft import idct
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration
|
||||
from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
PRECISION = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class PI0FASTPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = PI0FASTConfig
|
||||
name = "pi0fast"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0FASTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.model.generate_actions(batch)
|
||||
|
||||
actions = actions[:, : self.config.n_action_steps]
|
||||
|
||||
original_action_dim = self.config.action_feature.shape[
|
||||
0
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss_dict = self.model.forward(batch)
|
||||
return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
def block_causal_update_causal_mask(
|
||||
attention_mask,
|
||||
token_type_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
input_tensor=None,
|
||||
attn_implementation: str = "eager",
|
||||
dtype: torch.dtype = "float32",
|
||||
):
|
||||
"""
|
||||
Update the causal mask during training and generation. It can be customized to different attention masks.
|
||||
"""
|
||||
if attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
if input_tensor is None:
|
||||
input_tensor = attention_mask
|
||||
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
|
||||
if using_static_cache or isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
# Handle precomputed attention masks
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
return attention_mask
|
||||
|
||||
# Causal mask initialization
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
|
||||
# Standard causal masking (triu ensures tokens can only attend to past)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
|
||||
# Apply block causal mask
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.to(causal_mask.device).bool()
|
||||
cumsum = torch.cumsum(token_type_ids, dim=1)
|
||||
block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
|
||||
# Combine causal_mask with block-wise attention mask
|
||||
causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
|
||||
causal_mask = causal_mask[:, None, :, :]
|
||||
else:
|
||||
# Apply past cache position constraint
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
else:
|
||||
# Apply past cache position constraint
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# Apply padding mask
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
# self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=None,
|
||||
labels=None,
|
||||
self=None,
|
||||
**kwargs,
|
||||
):
|
||||
# create block causal attention
|
||||
if cache_position[0] > 0 and input_ids.shape[1] > 0:
|
||||
input_tensor = input_ids[:, -1:]
|
||||
new_positions = (
|
||||
torch.ones(
|
||||
(position_ids.shape[0], input_ids.shape[1]),
|
||||
dtype=position_ids.dtype,
|
||||
device=position_ids.device,
|
||||
).cumsum(-1)
|
||||
+ position_ids[:, -1:]
|
||||
)
|
||||
position_ids = torch.cat([position_ids, new_positions], dim=-1)
|
||||
else:
|
||||
input_tensor = inputs_embeds
|
||||
attention_mask = block_causal_update_causal_mask(
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
input_tensor=input_tensor,
|
||||
token_type_ids=token_type_ids,
|
||||
dtype=self.dtype,
|
||||
attn_implementation=self.config.text_config._attn_implementation,
|
||||
)
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
class PI0FAST(nn.Module):
|
||||
def __init__(self, config: PI0FASTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# TODO: move tokenizers in Policy
|
||||
fast_tokenizer_path = "physical-intelligence/fast"
|
||||
pi0_paligemma_path = "google/paligemma-3b-pt-224"
|
||||
self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
|
||||
self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
|
||||
self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
|
||||
self.fast_skip_tokens = self.config.fast_skip_tokens
|
||||
self.max_input_seq_len = self.config.max_input_seq_len
|
||||
self.action_horizon = self.config.chunk_size
|
||||
self.action_dim = self.config.action_feature.shape[
|
||||
0
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
precision = config.precision
|
||||
torch_precision = PRECISION.get(precision, torch.float32)
|
||||
self.pad_token_id = (
|
||||
self.paligemma_tokenizer.pad_token_id
|
||||
if hasattr(self.paligemma_tokenizer, "pad_token_id")
|
||||
else self.paligemma_tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"](
|
||||
transformers_version="4.48.1",
|
||||
_vocab_size=257152,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
hidden_size=2048,
|
||||
image_token_index=257152,
|
||||
model_type="paligemma",
|
||||
pad_token_id=0,
|
||||
projection_dim=2048,
|
||||
text_config={
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 16384,
|
||||
"model_type": "gemma",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 18,
|
||||
"num_image_tokens": 256,
|
||||
"num_key_value_heads": 1,
|
||||
"torch_dtype": precision,
|
||||
"vocab_size": 257152,
|
||||
"_attn_implementation": "eager",
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 1152,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 27,
|
||||
"num_image_tokens": 256,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
"projector_hidden_act": "gelu_pytorch_tanh",
|
||||
"torch_dtype": precision,
|
||||
"vision_use_head": False,
|
||||
},
|
||||
)
|
||||
self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
|
||||
|
||||
self.pi0_paligemma.prepare_inputs_for_generation = partial(
|
||||
prepare_inputs_for_generation, self=self.pi0_paligemma
|
||||
)
|
||||
# change important stuff in bf16
|
||||
params_to_change_dtype = [
|
||||
"language_model",
|
||||
"vision_tower",
|
||||
"multi_modal",
|
||||
]
|
||||
for name, param in self.pi0_paligemma.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch_precision)
|
||||
self.set_requires_grad()
|
||||
self.image_keys = self.config.image_features.keys()
|
||||
self.ignore_index = self.pi0_paligemma.config.ignore_index
|
||||
self.padding_side = self.config.padding_side
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.config.freeze_vision_encoder:
|
||||
self.pi0_paligemma.vision_tower.eval()
|
||||
for params in self.pi0_paligemma.vision_tower.parameters():
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
if self.config.freeze_lm_head:
|
||||
for name, params in self.pi0_paligemma.named_parameters():
|
||||
if "embed_tokens" in name: # lm heads and embedding layer are tied
|
||||
params.requires_grad = False
|
||||
|
||||
def embed_tokens(self, tokens: torch.Tensor):
|
||||
return self.pi0_paligemma.language_model.model.embed_tokens(tokens)
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Preprocess LeRobot batch into Pi0 inputs"""
|
||||
images = []
|
||||
img_masks = []
|
||||
present_img_keys = [key for key in self.image_keys if key in batch]
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
|
||||
# Preprocess image features present in the batch
|
||||
num_empty_cameras = 0
|
||||
for key in self.image_keys:
|
||||
if key in present_img_keys:
|
||||
img = batch[key]
|
||||
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(
|
||||
img,
|
||||
*self.config.resize_imgs_with_padding,
|
||||
pad_value=0,
|
||||
interpolate_like_pi=self.config.interpolate_like_pi,
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
else:
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
continue
|
||||
img = torch.ones_like(img) * -1
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
num_empty_cameras += 1
|
||||
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
mins = actions.amin(dim=(1, 2), keepdim=True) # [0]
|
||||
maxs = actions.amax(dim=(1, 2), keepdim=True) # [0]
|
||||
return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1
|
||||
|
||||
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
return out
|
||||
|
||||
def fast_tokenizer_wrapper(self, actions_norm):
|
||||
"""
|
||||
A wrapper for self.fast_tokenizer that ensures batch processing,
|
||||
conversion to PyTorch tensors, and returns a dictionary without padding.
|
||||
"""
|
||||
batch_tokens = self.fast_tokenizer(actions_norm)
|
||||
fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt")
|
||||
|
||||
return fast_out
|
||||
|
||||
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor:
|
||||
token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool)
|
||||
# Compute cumulative sum mask
|
||||
cumsum_mask = (padded_mask != 0).cumsum(dim=1)
|
||||
# Suffix block (everything after prefix_len)
|
||||
suffix_mask = cumsum_mask > prefix_len
|
||||
token_type_ids = suffix_mask
|
||||
return token_type_ids
|
||||
|
||||
def create_input_tokens(self, state, lang_text, actions=None):
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
|
||||
discretized = torch.bucketize(state, bins) - 1
|
||||
discretized = discretized[:, :32]
|
||||
|
||||
prefix_texts = []
|
||||
state_text = []
|
||||
for txt, disc in zip(lang_text, discretized, strict=False):
|
||||
cleaned = txt.lower().strip().replace("_", " ")
|
||||
state_str = " ".join(str(val.item()) for val in disc)
|
||||
prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
|
||||
state_text.append(f"State: {state_str};\n")
|
||||
|
||||
prefix_out = self.paligemma_tokenizer(
|
||||
prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
|
||||
)
|
||||
prefix_ids = prefix_out["input_ids"].to(device)
|
||||
prefix_mask = prefix_out["attention_mask"].to(device)
|
||||
prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
|
||||
|
||||
if actions is not None:
|
||||
actions_norm = self.normalize_actions(actions)
|
||||
actions_pad = F.pad(
|
||||
actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
|
||||
)[:, :, : self.config.max_action_dim]
|
||||
fast_out = self.fast_tokenizer_wrapper(
|
||||
actions_pad.cpu(),
|
||||
)
|
||||
act_ids = fast_out["input_ids"]
|
||||
act_mask = fast_out["attention_mask"].to(device)
|
||||
|
||||
act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
|
||||
# Replace action with 0 to pad tokens
|
||||
act_ids = torch.where(
|
||||
act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
|
||||
self.pad_token_id,
|
||||
act_ids,
|
||||
)
|
||||
|
||||
eos_token = torch.tensor(
|
||||
[self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
|
||||
).expand(bsize, -1)
|
||||
eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
|
||||
bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
|
||||
bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
|
||||
bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
|
||||
act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
|
||||
act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
|
||||
act_mask = act_mask.to(device)
|
||||
else:
|
||||
act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device)
|
||||
act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
|
||||
final_ids = torch.cat([prefix_ids, act_ids], dim=1)
|
||||
|
||||
final_mask = torch.cat([prefix_mask, act_mask], dim=1)
|
||||
batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
|
||||
|
||||
# Use tokenizer pad function
|
||||
padded_output = self.paligemma_tokenizer.pad(
|
||||
batch_inputs, padding="longest", max_length=180, return_tensors="pt"
|
||||
)
|
||||
padded_mask = padded_output["attention_mask"]
|
||||
|
||||
# define tensor of padding lengths
|
||||
att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
|
||||
|
||||
token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
|
||||
|
||||
padded_output["padded_mask"] = padded_output.pop("attention_mask")
|
||||
padded_output["attention_mask"] = att_mask
|
||||
# loss is computed not on prefix, and not on padding
|
||||
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
|
||||
padded_output["token_type_ids"] = token_type_ids
|
||||
return padded_output
|
||||
|
||||
def shift_padding_side(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
ar_mask: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
loss_mask: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
token_type_ids: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> tuple[torch.Tensor]:
|
||||
if padding_side not in ["right", "left"]:
|
||||
return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids
|
||||
|
||||
new_tokens = torch.empty_like(tokens)
|
||||
new_ar_masks = torch.empty_like(ar_mask)
|
||||
new_padding_mask = torch.empty_like(padding_mask)
|
||||
new_loss_mask = torch.empty_like(loss_mask)
|
||||
new_targets = torch.empty_like(targets)
|
||||
new_token_type_ids = torch.empty_like(token_type_ids)
|
||||
batch_size = tokens.shape[0]
|
||||
for i in range(batch_size):
|
||||
padding_indices = torch.where(padding_mask[i] == 0)[0]
|
||||
non_padding_indices = torch.where(padding_mask[i] == 1)[0]
|
||||
if padding_side == "left":
|
||||
new_indices = torch.cat((padding_indices, non_padding_indices), dim=0)
|
||||
else:
|
||||
new_indices = torch.cat((non_padding_indices, padding_indices), dim=0)
|
||||
new_tokens[i] = tokens[i].index_select(0, new_indices)
|
||||
new_ar_masks[i] = ar_mask[i].index_select(0, new_indices)
|
||||
new_padding_mask[i] = padding_mask[i].index_select(0, new_indices)
|
||||
new_loss_mask[i] = loss_mask[i].index_select(0, new_indices)
|
||||
new_targets[i] = targets[i].index_select(0, new_indices)
|
||||
new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices)
|
||||
|
||||
return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]):
|
||||
device = batch[OBS_ROBOT].device
|
||||
# TODO: keep like this or move to the policy .forward
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(
|
||||
state=batch[OBS_ROBOT],
|
||||
lang_text=batch["task"],
|
||||
actions=batch[ACTION],
|
||||
)
|
||||
|
||||
embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
padded_outs["input_ids"],
|
||||
padded_outs["padded_mask"],
|
||||
padded_outs["attention_mask"],
|
||||
padded_outs["loss_mask"],
|
||||
padded_outs["token_type_ids"],
|
||||
padding_side=self.padding_side,
|
||||
)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
||||
past_seen_tokens = 0
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
|
||||
pad_masks = block_causal_update_causal_mask(
|
||||
attention_mask=pad_masks,
|
||||
past_key_values=None,
|
||||
cache_position=cache_position,
|
||||
input_tensor=embs,
|
||||
token_type_ids=token_type_ids,
|
||||
dtype=self.pi0_paligemma.dtype,
|
||||
attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
|
||||
)
|
||||
outputs = self.pi0_paligemma.forward(
|
||||
input_ids=None,
|
||||
token_type_ids=None,
|
||||
attention_mask=pad_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=embs,
|
||||
use_cache=False,
|
||||
labels=None,
|
||||
)
|
||||
|
||||
logits = outputs.logits
|
||||
|
||||
loss_fct = nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Shift left for next-step prediction
|
||||
logits = logits[:, :-1, :]
|
||||
targets = targets[:, 1:].to(device) # Shift targets
|
||||
loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape
|
||||
|
||||
# Compute per-token loss
|
||||
token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
|
||||
|
||||
# Apply loss mask
|
||||
token_loss = token_loss * loss_mask.reshape(-1)
|
||||
|
||||
# Compute final loss
|
||||
loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)
|
||||
|
||||
# Return loss dictionary
|
||||
loss_dict = {"ce_loss": loss.item(), "loss": loss}
|
||||
return loss_dict
|
||||
|
||||
def decode_actions_with_fast(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
relaxed_decoding: bool = True,
|
||||
) -> np.array:
|
||||
"""
|
||||
Adapt original decoding in FAST to always return actions instead of zeros.
|
||||
"""
|
||||
self.time_horizon = (
|
||||
time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
|
||||
)
|
||||
self.action_dim = (
|
||||
action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
|
||||
)
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
)
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
|
||||
if relaxed_decoding:
|
||||
# Expected sequence length
|
||||
expected_seq_len = self.time_horizon * self.action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
# Apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right
|
||||
# Apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Extracts actions from predicted output tokens using the FAST model.
|
||||
|
||||
Args:
|
||||
tokens (torch.Tensor): The input tensor of tokenized outputs.
|
||||
action_horizon (int): The number of timesteps for actions.
|
||||
action_dim (int): The dimensionality of each action.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim).
|
||||
"""
|
||||
# Decode predicted output tokens
|
||||
decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
|
||||
cleaned_tokens = [
|
||||
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
|
||||
for tokens_sequence in decoded_tokens
|
||||
]
|
||||
raw_action_tokens = [
|
||||
self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
|
||||
for sample_tokens in cleaned_tokens
|
||||
] # something like this should be robust #looks good
|
||||
action_tokens = [
|
||||
self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
|
||||
]
|
||||
# returns the tensor of decoded actions per sample in a list
|
||||
decoded_actions = [
|
||||
torch.tensor(
|
||||
self.decode_actions_with_fast(
|
||||
tok.tolist(),
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
relaxed_decoding=self.config.relaxed_action_decoding,
|
||||
),
|
||||
device=tokens.device,
|
||||
).squeeze(0)
|
||||
for tok in action_tokens
|
||||
]
|
||||
|
||||
return torch.stack(
|
||||
decoded_actions,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]):
|
||||
# TODO: keep like this or move to the policy .forward
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
|
||||
padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)
|
||||
embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs(
|
||||
images,
|
||||
img_masks,
|
||||
padded_outs["input_ids"],
|
||||
padded_outs["padded_mask"],
|
||||
padded_outs["attention_mask"],
|
||||
padded_outs["loss_mask"],
|
||||
padded_outs["token_type_ids"],
|
||||
padding_side="left",
|
||||
)
|
||||
token_type_ids = token_type_ids.to(dtype=torch.int64)
|
||||
prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
output_tokens = self.pi0_paligemma.generate(
|
||||
input_ids=None,
|
||||
attention_mask=pad_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=embs,
|
||||
use_cache=self.config.use_cache,
|
||||
max_new_tokens=self.config.max_decoding_steps,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
|
||||
return actions
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
return self.pi0_paligemma.get_image_features(image)
|
||||
|
||||
def embed_inputs(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
pad_mask,
|
||||
ar_mask,
|
||||
loss_mask,
|
||||
token_type_ids,
|
||||
padding_side: str = "right",
|
||||
):
|
||||
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
|
||||
# images are a list of same size
|
||||
# vectorizing everything!
|
||||
device = images[0].device
|
||||
image_embedding_dim = images[0].shape[-1] # TODO should be from self.config
|
||||
all_images = torch.stack(images, dim=1).to(device)
|
||||
b, n, c, h, w = all_images.shape
|
||||
all_images = all_images.view(b * n, c, h, w)
|
||||
embedded = self.embed_image(all_images).to(device)
|
||||
b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions
|
||||
m = b_n // b # Compute the number of images per sample dynamically
|
||||
|
||||
# Reshape dynamically
|
||||
embedded = embedded.view(b, m, p, image_embedding_dim)
|
||||
tokens_embs = self.embed_tokens(tokens.to(device))
|
||||
|
||||
img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device)
|
||||
num_img_emb = embedded.shape[2]
|
||||
img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1)
|
||||
img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
||||
|
||||
image_target_tokens = (
|
||||
torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id
|
||||
).reshape(b, -1)
|
||||
image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1)
|
||||
|
||||
embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D)
|
||||
|
||||
embs = torch.cat([embedded, tokens_embs], dim=1).to(device)
|
||||
pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
|
||||
att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1)
|
||||
loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1)
|
||||
targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1)
|
||||
token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1)
|
||||
|
||||
# Shift pad tokens to the left (.generate()) or right (.train())
|
||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
|
||||
embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
|
||||
)
|
||||
|
||||
targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
|
||||
return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
|
||||
if interpolate_like_pi:
|
||||
img = (img * 255.0).to(dtype=torch.uint8)
|
||||
img = img.permute(0, 2, 3, 1)
|
||||
original_device = img.device
|
||||
img = img.to(device="cpu").numpy()
|
||||
imgs = []
|
||||
for sub_img in img:
|
||||
sub_img = Image.fromarray(sub_img)
|
||||
resized_img = sub_img.resize((resized_width, resized_height), resample=2)
|
||||
resized_img = torch.from_numpy(np.array(resized_img))
|
||||
imgs.append(resized_img)
|
||||
img = torch.stack(imgs, dim=0)
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0
|
||||
else:
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
@@ -51,8 +51,8 @@ class ActorNetworkConfig:
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
log_std_min: int = -5
|
||||
log_std_max: int = 2
|
||||
log_std_min: float = 1e-5
|
||||
log_std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ class SACConfig(PreTrainedConfig):
|
||||
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
|
||||
shared_encoder: Whether to use a shared encoder for actor and critic.
|
||||
num_discrete_actions: Number of discrete actions, eg for gripper actions.
|
||||
image_embedding_pooling_dim: Dimension of the image embedding pooling.
|
||||
concurrency: Configuration for concurrency settings.
|
||||
actor_learner: Configuration for actor-learner architecture.
|
||||
online_steps: Number of steps for online training.
|
||||
@@ -147,6 +148,7 @@ class SACConfig(PreTrainedConfig):
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
num_discrete_actions: int | None = None
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
online_steps: int = 1000000
|
||||
|
||||
@@ -22,13 +22,12 @@ from dataclasses import asdict
|
||||
from typing import Callable, List, Literal, Optional, Tuple
|
||||
|
||||
import einops
|
||||
from importlib_metadata import distribution
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Categorical
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -53,158 +52,46 @@ class SACPolicy(
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features,
|
||||
config.normalization_mapping,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_targets = nn.Identity()
|
||||
self.unnormalize_outputs = nn.Identity()
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
self.shared_encoder = config.shared_encoder
|
||||
|
||||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=target_critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.grasp_critic = None
|
||||
self.grasp_critic_target = None
|
||||
|
||||
if config.num_discrete_actions is not None:
|
||||
# Create grasp critic
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
output_dim=config.num_discrete_actions,
|
||||
softmax_temperature=1.0,
|
||||
**asdict(config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# Create target grasp critic
|
||||
self.grasp_critic_target = GraspCritic(
|
||||
encoder=encoder_critic,
|
||||
input_dim=encoder_critic.output_dim,
|
||||
output_dim=config.num_discrete_actions,
|
||||
softmax_temperature=1.0,
|
||||
**asdict(config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
|
||||
|
||||
self.grasp_critic = torch.compile(self.grasp_critic)
|
||||
self.grasp_critic_target = torch.compile(self.grasp_critic_target)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**asdict(config.policy_kwargs),
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(continuous_action_dim) / 2 # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
# it triggers "can't optimize a non-leaf Tensor"
|
||||
|
||||
temperature_init = config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
self._init_normalization(dataset_stats)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
"actor": self.actor.parameters_to_optimize,
|
||||
"critic": self.critic_ensemble.parameters_to_optimize,
|
||||
"actor": [
|
||||
p
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["grasp_critic"] = self.grasp_critic.parameters_to_optimize
|
||||
optim_params["grasp_critic"] = self.grasp_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
|
||||
if self.actor.fixed_std is not None:
|
||||
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
|
||||
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
|
||||
super().to(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
# We cached the encoder output to avoid recomputing it
|
||||
observations_features = None
|
||||
if self.shared_encoder:
|
||||
observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
_, discrete_action_distribution = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = discrete_action_distribution.sample().unsqueeze(-1).float()
|
||||
discrete_action_value = self.grasp_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
@@ -433,18 +320,19 @@ class SACPolicy(
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
gripper_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_grasp_qs, next_grasp_distribution = self.grasp_critic_forward(
|
||||
next_grasp_qs = self.grasp_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_grasp_action = next_grasp_distribution.sample().unsqueeze(-1)
|
||||
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_grasp_qs, _ = self.grasp_critic_forward(
|
||||
target_next_grasp_qs = self.grasp_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
@@ -455,14 +343,14 @@ class SACPolicy(
|
||||
target_next_grasp_qs, dim=1, index=best_next_grasp_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_gripper = rewards
|
||||
if gripper_penalties is not None:
|
||||
rewards_gripper = rewards + gripper_penalties
|
||||
target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_grasp_qs, _ = self.grasp_critic_forward(
|
||||
predicted_grasp_qs = self.grasp_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
@@ -502,109 +390,265 @@ class SACPolicy(
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_normalization(self, dataset_stats):
|
||||
"""Initialize input/output normalization modules."""
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = nn.Identity()
|
||||
self.unnormalize_outputs = nn.Identity()
|
||||
if self.config.dataset_stats:
|
||||
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
self.config.input_features, self.config.normalization_mapping, params
|
||||
)
|
||||
stats = dataset_stats or params
|
||||
self.normalize_targets = Normalize(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic
|
||||
if self.shared_encoder
|
||||
else SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional grasp critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_grasp_critics()
|
||||
|
||||
def _init_grasp_critics(self):
|
||||
"""Build discrete grasp critic ensemble and target networks."""
|
||||
self.grasp_critic = GraspCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.grasp_critic_network_kwargs),
|
||||
)
|
||||
self.grasp_critic_target = GraspCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.grasp_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the grasp critic
|
||||
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=self.shared_encoder,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
if self.config.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.config.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self):
|
||||
"""Set up temperature parameter and initial log_alpha."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
super(SACObservationEncoder, self).__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
self.parameters_to_optimize = []
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_features):
|
||||
self.camera_number = config.camera_number
|
||||
def _init_image_layers(self) -> None:
|
||||
self.image_keys = [k for k in self.config.input_features if k.startswith("observation.image")]
|
||||
self.has_images = bool(self.image_keys)
|
||||
if not self.has_images:
|
||||
return
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
self.has_pretrained_vision_encoder = True
|
||||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
if self.config.vision_encoder_name:
|
||||
self.image_encoder = PretrainedImageEncoder(self.config)
|
||||
else:
|
||||
self.image_encoder = DefaultImageEncoder(self.config)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
if self.config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_encoder)
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
|
||||
dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
|
||||
with torch.no_grad():
|
||||
_, channels, height, width = self.image_encoder(dummy).shape
|
||||
|
||||
if "observation.state" in config.input_features:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
self.spatial_embeddings = nn.ModuleDict()
|
||||
self.post_encoders = nn.ModuleDict()
|
||||
|
||||
for key in self.image_keys:
|
||||
name = key.replace(".", "_")
|
||||
self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
|
||||
height=height,
|
||||
width=width,
|
||||
channel=channels,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
)
|
||||
self.post_encoders[name] = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
in_features=channels * self.config.image_embedding_pooling_dim,
|
||||
out_features=self.config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.LayerNorm(normalized_shape=self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_features:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_features["observation.environment_state"].shape[0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
def _init_state_layers(self) -> None:
|
||||
self.has_env = "observation.environment_state" in self.config.input_features
|
||||
self.has_state = "observation.state" in self.config.input_features
|
||||
if self.has_env:
|
||||
dim = self.config.input_features["observation.environment_state"].shape[0]
|
||||
self.env_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
if self.has_state:
|
||||
dim = self.config.input_features["observation.state"].shape[0]
|
||||
self.state_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
def _compute_output_dim(self) -> None:
|
||||
out = 0
|
||||
if self.has_images:
|
||||
out += len(self.image_keys) * self.config.latent_dim
|
||||
if self.has_env:
|
||||
out += self.config.latent_dim
|
||||
if self.has_state:
|
||||
out += self.config.latent_dim
|
||||
self._out_dim = out
|
||||
|
||||
def forward(
|
||||
self, obs_dict: dict[str, Tensor], vision_encoder_cache: torch.Tensor | None = None
|
||||
self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
if self.has_state:
|
||||
parts.append(self.state_encoder(obs["observation.state"]))
|
||||
if parts:
|
||||
return torch.cat(parts, dim=-1)
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
raise ValueError(
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
the resulting features.
|
||||
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
|
||||
reused across policy components (actor, critic, grasp_critic), avoiding redundant forward passes.
|
||||
|
||||
Performance impact:
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward() with normalize=False
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
|
||||
vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
|
||||
|
||||
if vision_encoder_cache is not None:
|
||||
feat.append(vision_encoder_cache)
|
||||
|
||||
if "observation.environment_state" in self.config.input_features:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_features:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
|
||||
# [N*B, C, H, W]
|
||||
if normalize:
|
||||
batch = self.input_normalization(batch)
|
||||
if len(self.all_image_keys) > 0:
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
embeddings_image = torch.cat(embeddings_chunks, dim=-1)
|
||||
return embeddings_image
|
||||
return None
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
return dict(zip(self.image_keys, chunks, strict=False))
|
||||
|
||||
def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
|
||||
"""Encode image features from cached observations.
|
||||
|
||||
This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
|
||||
It also supports detaching the encoded features if specified.
|
||||
|
||||
Args:
|
||||
cache (dict[str, Tensor]): The cached image features.
|
||||
detach (bool): Usually when the encoder is shared between actor and critics,
|
||||
we want to detach the encoded features on the policy side to avoid backprop through the encoder.
|
||||
More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
|
||||
|
||||
Returns:
|
||||
Tensor: The encoded image features.
|
||||
"""
|
||||
feats = []
|
||||
for k, feat in cache.items():
|
||||
safe_key = k.replace(".", "_")
|
||||
x = self.spatial_embeddings[safe_key](feat)
|
||||
x = self.post_encoders[safe_key](x)
|
||||
if detach:
|
||||
x = x.detach()
|
||||
feats.append(x)
|
||||
return torch.cat(feats, dim=-1)
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
return self._out_dim
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
@@ -740,12 +784,6 @@ class CriticEnsemble(nn.Module):
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
# Handle the case where a part of the encoder if frozen
|
||||
if self.encoder is not None:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
|
||||
self.parameters_to_optimize += list(self.critics.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
@@ -761,7 +799,7 @@ class CriticEnsemble(nn.Module):
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, observation_features)
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
@@ -787,7 +825,6 @@ class GraspCritic(nn.Module):
|
||||
dropout_rate: Optional[float] = None,
|
||||
init_final: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
softmax_temperature: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
@@ -809,20 +846,14 @@ class GraspCritic(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
self.parameters_to_optimize += list(self.net.parameters())
|
||||
self.parameters_to_optimize += list(self.output_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device by cloning first to avoid inplace operations
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
q_values = self.output_layer(self.net(obs_enc))
|
||||
distribution = Categorical(logits=q_values / self.softmax_temperature)
|
||||
return q_values, distribution
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
return self.output_layer(self.net(obs_enc))
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
@@ -846,12 +877,8 @@ class Policy(nn.Module):
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
self.encoder_is_shared = encoder_is_shared
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
@@ -865,7 +892,6 @@ class Policy(nn.Module):
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
@@ -874,15 +900,15 @@ class Policy(nn.Module):
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = self.encoder(observations, vision_encoder_cache=observation_features)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# We detach the encoder if it is shared to avoid backprop through it
|
||||
# This is important to avoid the encoder to be updated through the policy
|
||||
obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
@@ -891,29 +917,20 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
std = torch.exp(log_std) # Match JAX "exp"
|
||||
std = torch.clamp(std, self.log_std_min, self.log_std_max) # Match JAX default clip
|
||||
else:
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
# Build transformed distribution
|
||||
dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
# Sample actions (reparameterized)
|
||||
actions = dist.rsample()
|
||||
|
||||
# Compute log_probs
|
||||
log_probs = dist.log_prob(actions)
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
||||
return actions, log_probs, means
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
@@ -961,21 +978,16 @@ class DefaultImageEncoder(nn.Module):
|
||||
nn.ReLU(),
|
||||
)
|
||||
# Get first image key from input features
|
||||
image_key = next(key for key in config.input_features.keys() if key.startswith("observation.image")) # noqa: SIM118
|
||||
dummy_batch = torch.zeros(1, *config.input_features[image_key].shape)
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.image_enc_layers(x)
|
||||
x = self.image_enc_layers(x)
|
||||
return x
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
|
||||
"""Freeze all parameters in the encoder"""
|
||||
for param in image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
@@ -983,18 +995,12 @@ class PretrainedImageEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
@@ -1005,19 +1011,10 @@ class PretrainedImageEncoder(nn.Module):
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
|
||||
# doesn't reach the classifier layer because we don't need it
|
||||
enc_feat = self.image_enc_layers(x).pooler_output
|
||||
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
|
||||
enc_feat = self.image_enc_layers(x).last_hidden_state
|
||||
return enc_feat
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
|
||||
"""Freeze all parameters in the encoder"""
|
||||
for param in image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
@@ -1030,6 +1027,112 @@ class Identity(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, C, H, W] where B is batch size,
|
||||
C is number of channels, H is height, and W is width
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] where F is the number of features
|
||||
"""
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RescaleFromTanh(Transform):
|
||||
def __init__(self, low: float = -1, high: float = 1):
|
||||
super().__init__()
|
||||
|
||||
self.low = low
|
||||
|
||||
self.high = high
|
||||
|
||||
def _call(self, x):
|
||||
# Rescale from (-1, 1) to (low, high)
|
||||
|
||||
return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
|
||||
|
||||
def _inverse(self, y):
|
||||
# Rescale from (low, high) back to (-1, 1)
|
||||
|
||||
return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
|
||||
|
||||
def log_abs_det_jacobian(self, x, y):
|
||||
# log|d(rescale)/dx| = sum(log(0.5 * (high - low)))
|
||||
|
||||
scale = 0.5 * (self.high - self.low)
|
||||
|
||||
return torch.sum(torch.log(scale), dim=-1)
|
||||
|
||||
|
||||
class TanhMultivariateNormalDiag(TransformedDistribution):
|
||||
def __init__(self, loc, scale_diag, low=None, high=None):
|
||||
base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag))
|
||||
|
||||
transforms = [TanhTransform(cache_size=1)]
|
||||
|
||||
if low is not None and high is not None:
|
||||
low = torch.as_tensor(low)
|
||||
|
||||
high = torch.as_tensor(high)
|
||||
|
||||
transforms.insert(0, RescaleFromTanh(low, high))
|
||||
|
||||
super().__init__(base_dist, transforms)
|
||||
|
||||
def mode(self):
|
||||
# Mode is mean of base distribution, passed through transforms
|
||||
|
||||
x = self.base_dist.mean
|
||||
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
def stddev(self):
|
||||
std = self.base_dist.stddev
|
||||
|
||||
x = std
|
||||
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params = {}
|
||||
for outer_key, inner_dict in normalization_params.items():
|
||||
@@ -1040,90 +1143,3 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# # Benchmark the CriticEnsemble performance
|
||||
# import time
|
||||
|
||||
# # Configuration
|
||||
# num_critics = 10
|
||||
# batch_size = 32
|
||||
# action_dim = 7
|
||||
# obs_dim = 64
|
||||
# hidden_dims = [256, 256]
|
||||
# num_iterations = 100
|
||||
|
||||
# print("Creating test environment...")
|
||||
|
||||
# # Create a simple dummy encoder
|
||||
# class DummyEncoder(nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.output_dim = obs_dim
|
||||
# self.parameters_to_optimize = []
|
||||
|
||||
# def forward(self, obs):
|
||||
# # Just return a random tensor of the right shape
|
||||
# # In practice, this would encode the observations
|
||||
# return torch.randn(batch_size, obs_dim, device=device)
|
||||
|
||||
# # Create critic heads
|
||||
# print(f"Creating {num_critics} critic heads...")
|
||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# critic_heads = [
|
||||
# CriticHead(
|
||||
# input_dim=obs_dim + action_dim,
|
||||
# hidden_dims=hidden_dims,
|
||||
# ).to(device)
|
||||
# for _ in range(num_critics)
|
||||
# ]
|
||||
|
||||
# # Create the critic ensemble
|
||||
# print("Creating CriticEnsemble...")
|
||||
# critic_ensemble = CriticEnsemble(
|
||||
# encoder=DummyEncoder().to(device),
|
||||
# ensemble=critic_heads,
|
||||
# output_normalization=nn.Identity(),
|
||||
# ).to(device)
|
||||
|
||||
# # Create random input data
|
||||
# print("Creating input data...")
|
||||
# obs_dict = {
|
||||
# "observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
# }
|
||||
# actions = torch.randn(batch_size, action_dim, device=device)
|
||||
|
||||
# # Warmup run
|
||||
# print("Warming up...")
|
||||
# _ = critic_ensemble(obs_dict, actions)
|
||||
|
||||
# # Time the forward pass
|
||||
# print(f"Running benchmark with {num_iterations} iterations...")
|
||||
# start_time = time.perf_counter()
|
||||
# for _ in range(num_iterations):
|
||||
# q_values = critic_ensemble(obs_dict, actions)
|
||||
# end_time = time.perf_counter()
|
||||
|
||||
# # Print results
|
||||
# elapsed_time = end_time - start_time
|
||||
# print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
# print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
# print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
|
||||
# Verify that all critic heads produce different outputs
|
||||
# This confirms each critic head is unique
|
||||
# print("\nVerifying critic outputs are different:")
|
||||
# for i in range(num_critics):
|
||||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
|
||||
from lerobot.configs import parser
|
||||
|
||||
@parser.wrap()
|
||||
def main(config: SACConfig):
|
||||
policy = SACPolicy(config=config)
|
||||
print("yolo")
|
||||
|
||||
main()
|
||||
|
||||
@@ -130,7 +130,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
# When the action queue is depleted, populate it again by querying the policy.
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
|
||||
|
||||
# Remove the time dimensions as it is not handled yet.
|
||||
for key in batch:
|
||||
|
||||
@@ -520,13 +520,13 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=str,
|
||||
type=int,
|
||||
default=640,
|
||||
help="Set the width for all cameras. If not provided, use the default width of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=str,
|
||||
type=int,
|
||||
default=480,
|
||||
help="Set the height for all cameras. If not provided, use the default height of each camera.",
|
||||
)
|
||||
|
||||
@@ -492,13 +492,13 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=str,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the width for all cameras. If not provided, use the default width of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=str,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the height for all cameras. If not provided, use the default height of each camera.",
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ class TeleoperateControlConfig(ControlConfig):
|
||||
fps: int | None = None
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
display_data: bool = False
|
||||
|
||||
|
||||
@ControlConfig.register_subclass("record")
|
||||
@@ -82,7 +82,7 @@ class RecordControlConfig(ControlConfig):
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Display all cameras on screen
|
||||
display_cameras: bool = True
|
||||
display_data: bool = False
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
@@ -118,6 +118,11 @@ class ReplayControlConfig(ControlConfig):
|
||||
@dataclass
|
||||
class RemoteRobotConfig(ControlConfig):
|
||||
log_interval: int = 100
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp)
|
||||
viewer_ip: str | None = None
|
||||
viewer_port: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -24,7 +24,7 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
import rerun as rr
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
@@ -190,13 +190,13 @@ def warmup_record(
|
||||
events,
|
||||
enable_teleoperation,
|
||||
warmup_time_s,
|
||||
display_cameras,
|
||||
display_data,
|
||||
fps,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=warmup_time_s,
|
||||
display_cameras=display_cameras,
|
||||
display_data=display_data,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=enable_teleoperation,
|
||||
@@ -208,7 +208,7 @@ def record_episode(
|
||||
dataset,
|
||||
events,
|
||||
episode_time_s,
|
||||
display_cameras,
|
||||
display_data,
|
||||
policy,
|
||||
fps,
|
||||
single_task,
|
||||
@@ -216,12 +216,11 @@ def record_episode(
|
||||
control_loop(
|
||||
robot=robot,
|
||||
control_time_s=episode_time_s,
|
||||
display_cameras=display_cameras,
|
||||
display_data=display_data,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
policy=policy,
|
||||
fps=fps,
|
||||
# record_delta_actions=record_delta_actions,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
)
|
||||
@@ -232,7 +231,7 @@ def control_loop(
|
||||
robot,
|
||||
control_time_s=None,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
events=None,
|
||||
policy: PreTrainedPolicy = None,
|
||||
@@ -267,8 +266,6 @@ def control_loop(
|
||||
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
# if record_delta_actions:
|
||||
# action["action"] = action["action"] - current_joint_positions
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
@@ -288,14 +285,15 @@ def control_loop(
|
||||
frame = {**observation, **action, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# if frame["next.done"]:
|
||||
# break
|
||||
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
|
||||
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
|
||||
for k, v in action.items():
|
||||
for i, vv in enumerate(v):
|
||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
rr.log(key, rr.Image(observation[key].numpy()), static=True)
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
@@ -337,12 +335,8 @@ def reset_follower_position(robot: Robot, target_position):
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
||||
if not is_headless():
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
if display_cameras:
|
||||
cv2.destroyAllWindows()
|
||||
if not is_headless() and listener is not None:
|
||||
listener.stop()
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
|
||||
@@ -94,7 +94,7 @@ class MetricsTracker:
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
):
|
||||
self.__dict__.update({k: None for k in self.__keys__})
|
||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
|
||||
@@ -135,15 +135,19 @@ python lerobot/scripts/control_robot.py \
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
|
||||
import rerun as rr
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlConfig,
|
||||
ControlPipelineConfig,
|
||||
RecordControlConfig,
|
||||
RemoteRobotConfig,
|
||||
@@ -153,6 +157,7 @@ from lerobot.common.robot_devices.control_configs import (
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
log_control_info,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
@@ -232,7 +237,7 @@ def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
|
||||
control_time_s=cfg.teleop_time_s,
|
||||
fps=cfg.fps,
|
||||
teleoperate=True,
|
||||
display_cameras=cfg.display_cameras,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -287,7 +292,7 @@ def record(
|
||||
events,
|
||||
enable_teleoperation,
|
||||
cfg.warmup_time_s,
|
||||
cfg.display_cameras,
|
||||
cfg.display_data,
|
||||
cfg.fps,
|
||||
)
|
||||
|
||||
@@ -305,7 +310,7 @@ def record(
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
episode_time_s=cfg.episode_time_s,
|
||||
display_cameras=cfg.display_cameras,
|
||||
display_data=cfg.display_data,
|
||||
policy=policy,
|
||||
fps=cfg.fps,
|
||||
single_task=cfg.single_task,
|
||||
@@ -335,7 +340,7 @@ def record(
|
||||
break
|
||||
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, cfg.display_cameras)
|
||||
stop_recording(robot, listener, cfg.display_data)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||
@@ -363,8 +368,6 @@ def replay(
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"]
|
||||
# if replay_delta_actions:
|
||||
# action = action + current_joint_positions
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
@@ -374,6 +377,40 @@ def replay(
|
||||
log_control_info(robot, dt_s, fps=cfg.fps)
|
||||
|
||||
|
||||
def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop.
|
||||
|
||||
Args:
|
||||
control_config: Configuration determining data display and robot type.
|
||||
session_name: Rerun session name. Defaults to "lerobot_control_loop".
|
||||
|
||||
Raises:
|
||||
ValueError: If viewer IP is missing for non-remote configurations with display enabled.
|
||||
"""
|
||||
if (control_config.display_data and not is_headless()) or (
|
||||
control_config.display_data and isinstance(control_config, RemoteRobotConfig)
|
||||
):
|
||||
# Configure Rerun flush batch size default to 8KB if not set
|
||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||
|
||||
# Initialize Rerun based on configuration
|
||||
rr.init(session_name)
|
||||
if isinstance(control_config, RemoteRobotConfig):
|
||||
viewer_ip = control_config.viewer_ip
|
||||
viewer_port = control_config.viewer_port
|
||||
if not viewer_ip or not viewer_port:
|
||||
raise ValueError(
|
||||
"Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data."
|
||||
)
|
||||
logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}")
|
||||
rr.connect_tcp(f"{viewer_ip}:{viewer_port}")
|
||||
else:
|
||||
# Get memory limit for rerun viewer parameters
|
||||
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def control_robot(cfg: ControlPipelineConfig):
|
||||
init_logging()
|
||||
@@ -381,17 +418,22 @@ def control_robot(cfg: ControlPipelineConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
# TODO(Steven): Blueprint for fixed window size
|
||||
|
||||
if isinstance(cfg.control, CalibrateControlConfig):
|
||||
calibrate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, TeleoperateControlConfig):
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop")
|
||||
teleoperate(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RecordControlConfig):
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record")
|
||||
record(robot, cfg.control)
|
||||
elif isinstance(cfg.control, ReplayControlConfig):
|
||||
replay(robot, cfg.control)
|
||||
elif isinstance(cfg.control, RemoteRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi
|
||||
|
||||
_init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote")
|
||||
run_lekiwi(cfg.robot)
|
||||
|
||||
if robot.is_connected:
|
||||
|
||||
@@ -66,7 +66,7 @@ from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
@@ -124,7 +124,6 @@ def rollout(
|
||||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
@@ -145,6 +144,7 @@ def rollout(
|
||||
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
|
||||
leave=False,
|
||||
)
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done):
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
@@ -155,6 +155,10 @@ def rollout(
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
|
||||
@@ -231,6 +231,7 @@ def act_with_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
@@ -78,9 +78,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(
|
||||
val, device=device, non_blocking=non_blocking
|
||||
)
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
@@ -505,7 +503,6 @@ class ReplayBuffer:
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
action_mask: Optional[Sequence[int]] = None,
|
||||
action_delta: Optional[float] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
@@ -520,7 +517,6 @@ class ReplayBuffer:
|
||||
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
|
||||
action_delta (Optional[float]): Factor to divide actions by.
|
||||
image_augmentation_function (Optional[Callable]): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
@@ -565,9 +561,6 @@ class ReplayBuffer:
|
||||
else:
|
||||
first_action = first_action[:, action_mask]
|
||||
|
||||
if action_delta is not None:
|
||||
first_action = first_action / action_delta
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
@@ -598,9 +591,6 @@ class ReplayBuffer:
|
||||
else:
|
||||
action = action[:, action_mask]
|
||||
|
||||
if action_delta is not None:
|
||||
action = action / action_delta
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
|
||||
@@ -42,7 +42,6 @@ class HILSerlRobotEnv(gym.Env):
|
||||
self,
|
||||
robot,
|
||||
use_delta_action_space: bool = True,
|
||||
delta: float | None = None,
|
||||
display_cameras: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -55,8 +54,6 @@ class HILSerlRobotEnv(gym.Env):
|
||||
robot: The robot interface object used to connect and interact with the physical robot.
|
||||
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
|
||||
joint positions are used.
|
||||
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
|
||||
0 and 1 when using a delta action space.
|
||||
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -74,7 +71,6 @@ class HILSerlRobotEnv(gym.Env):
|
||||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
|
||||
self.delta = delta
|
||||
self.use_delta_action_space = use_delta_action_space
|
||||
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
@@ -374,7 +370,7 @@ class RewardWrapper(gym.Wrapper):
|
||||
self.device = device
|
||||
|
||||
def step(self, action):
|
||||
observation, _, terminated, truncated, info = self.env.step(action)
|
||||
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||
images = [
|
||||
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
for key in observation
|
||||
@@ -382,15 +378,17 @@ class RewardWrapper(gym.Wrapper):
|
||||
]
|
||||
start_time = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
reward = (
|
||||
success = (
|
||||
self.reward_classifier.predict_reward(images, threshold=0.8)
|
||||
if self.reward_classifier is not None
|
||||
else 0.0
|
||||
)
|
||||
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
if reward == 1.0:
|
||||
if success == 1.0:
|
||||
terminated = True
|
||||
reward = 1.0
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
@@ -720,11 +718,13 @@ class ResetWrapper(gym.Wrapper):
|
||||
env: HILSerlRobotEnv,
|
||||
reset_pose: np.ndarray | None = None,
|
||||
reset_time_s: float = 5,
|
||||
open_gripper_on_reset: bool = False,
|
||||
):
|
||||
super().__init__(env)
|
||||
self.reset_time_s = reset_time_s
|
||||
self.reset_pose = reset_pose
|
||||
self.robot = self.unwrapped.robot
|
||||
self.open_gripper_on_reset = open_gripper_on_reset
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
if self.reset_pose is not None:
|
||||
@@ -733,6 +733,14 @@ class ResetWrapper(gym.Wrapper):
|
||||
reset_follower_position(self.robot, self.reset_pose)
|
||||
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
log_say("Reset the environment done.", play_sounds=True)
|
||||
if self.open_gripper_on_reset:
|
||||
current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
|
||||
current_joint_pos[-1] = MAX_GRIPPER_COMMAND
|
||||
self.robot.send_action(torch.from_numpy(current_joint_pos))
|
||||
busy_wait(0.1)
|
||||
current_joint_pos[-1] = 0.0
|
||||
self.robot.send_action(torch.from_numpy(current_joint_pos))
|
||||
busy_wait(0.2)
|
||||
else:
|
||||
log_say(
|
||||
f"Manually reset the environment for {self.reset_time_s} seconds.",
|
||||
@@ -762,37 +770,48 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
|
||||
|
||||
|
||||
class GripperPenaltyWrapper(gym.RewardWrapper):
|
||||
def __init__(self, env, penalty: float = -0.1):
|
||||
def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True):
|
||||
super().__init__(env)
|
||||
self.penalty = penalty
|
||||
self.gripper_penalty_in_reward = gripper_penalty_in_reward
|
||||
self.last_gripper_state = None
|
||||
|
||||
def reward(self, reward, action):
|
||||
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
|
||||
|
||||
if isinstance(action, tuple):
|
||||
action = action[0]
|
||||
action_normalized = action[-1] / MAX_GRIPPER_COMMAND
|
||||
action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND
|
||||
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
|
||||
gripper_state_normalized > 0.9 and action_normalized < 0.1
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and action_normalized < -0.5
|
||||
)
|
||||
breakpoint()
|
||||
|
||||
return reward + self.penalty * gripper_penalty_bool
|
||||
return reward + self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
def step(self, action):
|
||||
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
|
||||
if isinstance(action, tuple):
|
||||
gripper_action = action[0][-1]
|
||||
else:
|
||||
gripper_action = action[-1]
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
reward = self.reward(reward, action)
|
||||
gripper_penalty = self.reward(reward, gripper_action)
|
||||
|
||||
if self.gripper_penalty_in_reward:
|
||||
reward += gripper_penalty
|
||||
else:
|
||||
info["gripper_penalty"] = gripper_penalty
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self.last_gripper_state = None
|
||||
return super().reset(**kwargs)
|
||||
obs, info = super().reset(**kwargs)
|
||||
if self.gripper_penalty_in_reward:
|
||||
info["gripper_penalty"] = 0.0
|
||||
return obs, info
|
||||
|
||||
|
||||
class GripperQuantizationWrapper(gym.ActionWrapper):
|
||||
class GripperActionWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env, quantization_threshold: float = 0.2):
|
||||
super().__init__(env)
|
||||
self.quantization_threshold = quantization_threshold
|
||||
@@ -801,16 +820,18 @@ class GripperQuantizationWrapper(gym.ActionWrapper):
|
||||
is_intervention = False
|
||||
if isinstance(action, tuple):
|
||||
action, is_intervention = action
|
||||
|
||||
gripper_command = action[-1]
|
||||
# Quantize gripper command to -1, 0 or 1
|
||||
if gripper_command < -self.quantization_threshold:
|
||||
gripper_command = -MAX_GRIPPER_COMMAND
|
||||
elif gripper_command > self.quantization_threshold:
|
||||
gripper_command = MAX_GRIPPER_COMMAND
|
||||
else:
|
||||
gripper_command = 0.0
|
||||
|
||||
# Gripper actions are between 0, 2
|
||||
# we want to quantize them to -1, 0 or 1
|
||||
gripper_command = gripper_command - 1.0
|
||||
|
||||
if self.quantization_threshold is not None:
|
||||
# Quantize gripper command to -1, 0 or 1
|
||||
gripper_command = (
|
||||
np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
|
||||
)
|
||||
gripper_command = gripper_command * MAX_GRIPPER_COMMAND
|
||||
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
|
||||
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
|
||||
action[-1] = gripper_action.item()
|
||||
@@ -836,10 +857,12 @@ class EEActionWrapper(gym.ActionWrapper):
|
||||
]
|
||||
)
|
||||
if self.use_gripper:
|
||||
action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
|
||||
# gripper actions open at 2.0, and closed at 0.0
|
||||
min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
|
||||
max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
|
||||
ee_action_space = gym.spaces.Box(
|
||||
low=-action_space_bounds,
|
||||
high=action_space_bounds,
|
||||
low=min_action_space_bounds,
|
||||
high=max_action_space_bounds,
|
||||
shape=(3 + int(self.use_gripper),),
|
||||
dtype=np.float32,
|
||||
)
|
||||
@@ -997,11 +1020,11 @@ class GamepadControlWrapper(gym.Wrapper):
|
||||
if self.use_gripper:
|
||||
gripper_command = self.controller.gripper_command()
|
||||
if gripper_command == "open":
|
||||
gamepad_action = np.concatenate([gamepad_action, [1.0]])
|
||||
gamepad_action = np.concatenate([gamepad_action, [2.0]])
|
||||
elif gripper_command == "close":
|
||||
gamepad_action = np.concatenate([gamepad_action, [-1.0]])
|
||||
else:
|
||||
gamepad_action = np.concatenate([gamepad_action, [0.0]])
|
||||
else:
|
||||
gamepad_action = np.concatenate([gamepad_action, [1.0]])
|
||||
|
||||
# Check episode ending buttons
|
||||
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
|
||||
@@ -1141,7 +1164,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
env = HILSerlRobotEnv(
|
||||
robot=robot,
|
||||
display_cameras=cfg.wrapper.display_cameras,
|
||||
delta=cfg.wrapper.delta_action,
|
||||
use_delta_action_space=cfg.wrapper.use_relative_joint_positions
|
||||
and cfg.wrapper.ee_action_space_params is None,
|
||||
)
|
||||
@@ -1165,10 +1187,13 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||
if cfg.wrapper.use_gripper:
|
||||
env = GripperQuantizationWrapper(
|
||||
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
|
||||
)
|
||||
# env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
|
||||
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
|
||||
if cfg.wrapper.gripper_penalty is not None:
|
||||
env = GripperPenaltyWrapper(
|
||||
env=env,
|
||||
penalty=cfg.wrapper.gripper_penalty,
|
||||
gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward,
|
||||
)
|
||||
|
||||
if cfg.wrapper.ee_action_space_params is not None:
|
||||
env = EEActionWrapper(
|
||||
@@ -1176,6 +1201,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
ee_action_space_params=cfg.wrapper.ee_action_space_params,
|
||||
use_gripper=cfg.wrapper.use_gripper,
|
||||
)
|
||||
|
||||
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
|
||||
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
|
||||
env = GamepadControlWrapper(
|
||||
@@ -1192,6 +1218,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
env=env,
|
||||
reset_pose=cfg.wrapper.fixed_reset_joint_positions,
|
||||
reset_time_s=cfg.wrapper.reset_time_s,
|
||||
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset,
|
||||
)
|
||||
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
|
||||
@@ -1341,11 +1368,10 @@ def record_dataset(env, policy, cfg):
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def replay_episode(env, repo_id, root=None, episode=0):
|
||||
def replay_episode(env, cfg):
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
local_files_only = root is not None
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
|
||||
env.reset()
|
||||
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
@@ -1353,7 +1379,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"][:4]
|
||||
action = actions[idx]["action"]
|
||||
env.step((action, False))
|
||||
# env.step((action / env.unwrapped.delta, False))
|
||||
|
||||
@@ -1384,9 +1410,7 @@ def main(cfg: EnvConfig):
|
||||
if cfg.mode == "replay":
|
||||
replay_episode(
|
||||
env,
|
||||
cfg.replay_repo_id,
|
||||
root=cfg.dataset_root,
|
||||
episode=cfg.replay_episode,
|
||||
cfg=cfg,
|
||||
)
|
||||
exit()
|
||||
|
||||
|
||||
@@ -407,6 +407,7 @@ def add_actor_information_and_train(
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss (includes both main critic and grasp critic)
|
||||
@@ -428,7 +429,7 @@ def add_actor_information_and_train(
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -492,7 +493,7 @@ def add_actor_information_and_train(
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
@@ -509,7 +510,7 @@ def add_actor_information_and_train(
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
@@ -771,17 +772,18 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
params=policy.grasp_critic.parameters_to_optimize, lr=cfg.policy.critic_lr
|
||||
params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
@@ -992,7 +994,6 @@ def initialize_offline_replay_buffer(
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_features.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
capacity=cfg.policy.offline_buffer_capacity,
|
||||
@@ -1026,8 +1027,10 @@ def get_observation_features(
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||
next_observations, normalize=True
|
||||
)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
@@ -1089,6 +1092,44 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
|
||||
"""
|
||||
Checks whether each parameter in the module has a gradient.
|
||||
|
||||
Args:
|
||||
module (nn.Module): A PyTorch module whose parameters will be inspected.
|
||||
|
||||
Returns:
|
||||
dict[str, bool]: A dictionary where each key is the parameter name and the value is
|
||||
True if the parameter has an associated gradient (i.e. .grad is not None),
|
||||
otherwise False.
|
||||
"""
|
||||
grad_status = {}
|
||||
for name, param in module.named_parameters():
|
||||
grad_status[name] = param.grad is not None
|
||||
return grad_status
|
||||
|
||||
|
||||
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
|
||||
"""
|
||||
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
|
||||
|
||||
Args:
|
||||
actor (nn.Module): The actor model.
|
||||
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
|
||||
whether each parameter has a gradient.
|
||||
|
||||
Returns:
|
||||
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
|
||||
"""
|
||||
# Get actor parameter names as a set.
|
||||
model_param_names = {name for name, _ in model.named_parameters()}
|
||||
|
||||
# Intersect parameter names between actor and grad_status.
|
||||
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
|
||||
return overlapping
|
||||
|
||||
|
||||
def process_interaction_message(
|
||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||
):
|
||||
|
||||
@@ -211,7 +211,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
|
||||
@@ -179,7 +179,9 @@ def run_server(
|
||||
]
|
||||
videos_info = [
|
||||
{
|
||||
"url": url_for("static", filename=video_path),
|
||||
"url": url_for(
|
||||
"static", filename=str(video_path).replace("\\", "/")
|
||||
),
|
||||
"filename": video_path.parent.name,
|
||||
}
|
||||
for video_path in video_paths
|
||||
@@ -391,7 +393,7 @@ def visualize_dataset_html(
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix())
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
||||
@@ -61,9 +61,9 @@ dependencies = [
|
||||
"jsonlines>=4.0.0",
|
||||
"numba>=0.59.0",
|
||||
"omegaconf>=2.3.0",
|
||||
"opencv-python>=4.9.0",
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"packaging>=24.2",
|
||||
"av>=12.0.5,<13.0.0",
|
||||
"av>=12.0.5",
|
||||
"protobuf>=5.29.3",
|
||||
"pymunk>=6.6.0",
|
||||
"pynput>=1.7.7",
|
||||
@@ -71,7 +71,7 @@ dependencies = [
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l'))",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchmetrics>=1.6.0",
|
||||
"torchvision>=0.21.0",
|
||||
"transformers>=4.47.0",
|
||||
|
||||
@@ -172,8 +172,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
push_to_hub=False,
|
||||
# TODO(rcadene, aliberts): test video=True
|
||||
video=False,
|
||||
# TODO(rcadene): display cameras through cv2 sometimes crashes on mac
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
@@ -226,7 +225,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
@@ -273,7 +272,7 @@ def test_resume_record(tmp_path, request, robot_type, mock):
|
||||
episode_time_s=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_episodes=1,
|
||||
)
|
||||
@@ -330,7 +329,7 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
@@ -380,7 +379,7 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
|
||||
@@ -434,7 +433,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
display_data=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user