forked from tangger/lerobot
Compare commits
15 Commits
fix/lint_w ... user/pepij
1892aa1b08
3b6fff70e1
6e97876e81
4bdbf2f6e0
4e9b4dd380
17d12db7c4
6a8be97bb5
841d54c050
e3c3c165aa
f994febca4
12f52632ed
8a64d8268b
84565c7c2e
05b54733da
513b008bcc
@@ -29,7 +29,7 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/crate-ci/typos
-    rev: v1.30.0
+    rev: v1
     hooks:
       - id: typos
         args: [--force-exclude]
@@ -38,7 +38,7 @@ repos:
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.9.9
+    rev: v0.9.10
     hooks:
       - id: ruff
         args: [--fix]

@@ -232,8 +232,8 @@ python lerobot/scripts/eval.py \
     --env.type=pusht \
     --eval.batch_size=10 \
     --eval.n_episodes=10 \
-    --use_amp=false \
-    --device=cuda
+    --policy.use_amp=false \
+    --policy.device=cuda
 ```
 
 Note: After training your own policy, you can re-evaluate the checkpoints with:
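The re-evaluation command itself falls outside the hunk above. As a sketch only, assuming the default checkpoint layout under `outputs/train/` (the `act_pusht` run name is hypothetical, not part of this diff), it would look like:

```bash
python lerobot/scripts/eval.py \
    --policy.path=outputs/train/act_pusht/checkpoints/last/pretrained_model \
    --env.type=pusht \
    --eval.batch_size=10 \
    --eval.n_episodes=10 \
    --policy.use_amp=false \
    --policy.device=cuda
```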
@@ -454,8 +454,8 @@ Next, you'll need to calibrate your SO-100 robot to ensure that the leader and f
 
 You will need to move the follower arm to these positions sequentially:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
 
 Make sure both arms are connected and run this script to launch manual calibration:

@@ -470,8 +470,8 @@ python lerobot/scripts/control_robot.py \
 
 #### b. Manual calibration of leader arm
 Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
 
 Run this script to launch manual calibration:

@@ -571,14 +571,14 @@ python lerobot/scripts/train.py \
     --policy.type=act \
     --output_dir=outputs/train/act_so100_test \
     --job_name=act_so100_test \
-    --device=cuda \
+    --policy.device=cuda \
     --wandb.enable=true
 ```
 
 Let's explain it:
 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
 5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
 
 Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
@@ -366,8 +366,8 @@ Now we have to calibrate the leader arm and the follower arm. The wheel motors d
 
 You will need to move the follower arm to these positions sequentially:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
 
 Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:

@@ -385,8 +385,8 @@ If you have the **wired** LeKiwi version please run all commands including this
 
 ### Calibrate leader arm
 Then calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
 
 Run this script (on your laptop/pc) to launch manual calibration:

@@ -416,22 +416,22 @@ python lerobot/scripts/control_robot.py \
 
 You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backward, right. Use (z,x) to turn left or right, and (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
 | Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
-|------------|-------------------|-----------------------|
-| Fast | 0.4 | 90 |
-| Medium | 0.25 | 60 |
-| Slow | 0.1 | 30 |
+| ---------- | ------------------ | ---------------------- |
+| Fast       | 0.4                | 90                     |
+| Medium     | 0.25               | 60                     |
+| Slow       | 0.1                | 30                     |
 
 
-| Key | Action |
-|------|--------------------------------|
-| W | Move forward |
-| A | Move left |
-| S | Move backward |
-| D | Move right |
-| Z | Turn left |
-| X | Turn right |
-| R | Increase speed |
-| F | Decrease speed |
+| Key | Action         |
+| --- | -------------- |
+| W   | Move forward   |
+| A   | Move left      |
+| S   | Move backward  |
+| D   | Move right     |
+| Z   | Turn left      |
+| X   | Turn right     |
+| R   | Increase speed |
+| F   | Decrease speed |
 
 > [!TIP]
 > If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
@@ -549,14 +549,14 @@ python lerobot/scripts/train.py \
     --policy.type=act \
     --output_dir=outputs/train/act_lekiwi_test \
     --job_name=act_lekiwi_test \
-    --device=cuda \
+    --policy.device=cuda \
     --wandb.enable=true
 ```
 
 Let's explain it:
 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
 5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
 
 Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
@@ -176,8 +176,8 @@ Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and
 
 You will need to move the follower arm to these positions sequentially:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
 
 Make sure both arms are connected and run this script to launch manual calibration:

@@ -192,8 +192,8 @@ python lerobot/scripts/control_robot.py \
 
 **Manual calibration of leader arm**
 Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
 
 Run this script to launch manual calibration:

@@ -293,14 +293,14 @@ python lerobot/scripts/train.py \
     --policy.type=act \
     --output_dir=outputs/train/act_moss_test \
     --job_name=act_moss_test \
-    --device=cuda \
+    --policy.device=cuda \
     --wandb.enable=true
 ```
 
 Let's explain it:
 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
 5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
 
 Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
@@ -1,5 +1,5 @@
 This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
-> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
+> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
 
 
 ## The training script
@@ -386,14 +386,14 @@ When you connect your robot for the first time, the [`ManipulatorRobot`](../lero
 
 Here are the positions you'll move the follower arm to:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
 
 And here are the corresponding positions for the leader arm:
 
-| 1. Zero position | 2. Rotated position | 3. Rest position |
-|---|---|---|
+| 1. Zero position | 2. Rotated position | 3. Rest position |
+| ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
 | <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
 
 You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.

@@ -898,14 +898,14 @@ python lerobot/scripts/train.py \
     --policy.type=act \
     --output_dir=outputs/train/act_koch_test \
     --job_name=act_koch_test \
-    --device=cuda \
+    --policy.device=cuda \
     --wandb.enable=true
 ```
 
 Let's explain it:
 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
 5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
 
 For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)

@@ -135,14 +135,14 @@ python lerobot/scripts/train.py \
     --policy.type=act \
     --output_dir=outputs/train/act_aloha_test \
     --job_name=act_aloha_test \
-    --device=cuda \
+    --policy.device=cuda \
     --wandb.enable=true
 ```
 
 Let's explain it:
 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `device=cuda` since we are training on an Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
+4. We provided `policy.device=cuda` since we are training on an Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
 5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
 
 For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
@@ -749,6 +749,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         task_idx = item["task_index"].item()
         item["task"] = self.meta.tasks[task_idx]
 
+        # Add global index of frame (indices)
+        item["indices"] = torch.tensor(idx)
+
         return item
 
     def __repr__(self):
@@ -13,9 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterator, Union
+import random
+from typing import Iterator, List, Optional, Union
 
 import torch
 from torch.utils.data import Sampler
 
 
 class EpisodeAwareSampler:
@@ -59,3 +61,123 @@ class EpisodeAwareSampler:
 
     def __len__(self) -> int:
         return len(self.indices)
+
+
+class SumTree:
+    """
+    A classic sum-tree data structure for storing priorities.
+    Each leaf stores a sample's priority, and internal nodes store sums of children.
+    """
+
+    def __init__(self, capacity: int):
+        """
+        Args:
+            capacity: Maximum number of elements.
+        """
+        self.capacity = capacity
+        self.size = capacity
+        self.tree = [0.0] * (2 * self.size)
+
+    def initialize_tree(self, priorities: List[float]):
+        """
+        Initializes the sum tree.
+        """
+        # Set leaf values
+        for i, priority in enumerate(priorities):
+            self.tree[i + self.size] = priority
+
+        # Compute internal node values
+        for i in range(self.size - 1, 0, -1):
+            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]
+
+    def update(self, idx: int, priority: float):
+        """
+        Update the priority at leaf index `idx` and propagate changes upwards.
+        """
+        tree_idx = idx + self.size
+        self.tree[tree_idx] = priority  # Set new priority
+
+        # Propagate up, explicitly summing children
+        tree_idx //= 2
+        while tree_idx >= 1:
+            self.tree[tree_idx] = self.tree[2 * tree_idx] + self.tree[2 * tree_idx + 1]
+            tree_idx //= 2
+
+    def total_priority(self) -> float:
+        """Returns the sum of all priorities (stored at root)."""
+        return self.tree[1]
+
+    def sample(self, value: float) -> int:
+        """
+        Samples an index where the prefix sum up to that leaf is >= `value`.
+        """
+        value = min(max(value, 0), self.total_priority())  # Clamp value
+        idx = 1
+        while idx < self.size:
+            left = 2 * idx
+            if self.tree[left] >= value:
+                idx = left
+            else:
+                value -= self.tree[left]
+                idx = left + 1
+        return idx - self.size  # Convert tree index to data index
+
+
+class PrioritizedSampler(Sampler[int]):
+    """
+    PyTorch Sampler that draws samples in proportion to their priority using a SumTree.
+    """
+
+    def __init__(
+        self,
+        data_len: int,
+        alpha: float = 0.6,
+        eps: float = 1e-6,
+        num_samples_per_epoch: Optional[int] = None,
+    ):
+        """
+        Args:
+            data_len: Total number of samples in the dataset.
+            alpha: Exponent for priority scaling. Default is 0.6.
+            eps: Small constant to avoid zero priorities.
+            num_samples_per_epoch: Number of samples per epoch (default is data_len).
+        """
+        self.data_len = data_len
+        self.alpha = alpha
+        self.eps = eps
+        self.num_samples_per_epoch = num_samples_per_epoch or data_len
+
+        # Initialize difficulties and sum-tree
+        self.difficulties = [1.0] * data_len
+        self.priorities = [0.0] * data_len
+        initial_priorities = [(1.0 + eps) ** alpha] * data_len
+
+        self.sumtree = SumTree(data_len)
+        self.sumtree.initialize_tree(initial_priorities)
+        for i, p in enumerate(initial_priorities):
+            self.priorities[i] = p
+
+    def update_priorities(self, indices: List[int], difficulties: List[float]):
+        """
+        Updates the priorities in the sum-tree.
+        """
+        for idx, diff in zip(indices, difficulties, strict=False):
+            self.difficulties[idx] = diff
+            new_priority = (diff + self.eps) ** self.alpha
+            self.priorities[idx] = new_priority
+            self.sumtree.update(idx, new_priority)
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Samples indices based on their priority weights.
+        """
+        total_p = self.sumtree.total_priority()
+
+        for _ in range(self.num_samples_per_epoch):
+            r = random.random() * total_p
+            idx = self.sumtree.sample(r)
+
+            yield idx
+
+    def __len__(self) -> int:
+        return self.num_samples_per_epoch
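A minimal usage sketch of the two classes added above (not part of the diff itself): `IndexedDataset` is a hypothetical stand-in that yields `(index, value)` pairs, mirroring the `indices` key this branch adds to `LeRobotDataset.__getitem__`. Drawing a uniform value in `[0, total_priority)` and descending the tree selects leaf `i` with probability `priority_i / total_priority` in O(log n).

```python
import torch
from torch.utils.data import DataLoader, Dataset

from lerobot.common.datasets.sampler import PrioritizedSampler


class IndexedDataset(Dataset):
    """Toy dataset returning (index, value) so priorities can be updated per sample."""

    def __init__(self, n: int):
        self.data = torch.randn(n, 4)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, i: int):
        return i, self.data[i]


ds = IndexedDataset(100)
sampler = PrioritizedSampler(data_len=len(ds), alpha=0.6, eps=1e-6)
loader = DataLoader(ds, batch_size=8, sampler=sampler)

for idxs, x in loader:
    per_sample_loss = x.abs().mean(dim=1)  # stand-in difficulty signal
    # Priorities become (difficulty + eps) ** alpha inside the sampler's sum-tree,
    # so harder samples are drawn more often on subsequent passes.
    sampler.update_priorities(idxs.tolist(), per_sample_loss.tolist())
```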
@@ -155,11 +155,14 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
 
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        elementwise_l1 = F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch[
+            "action_is_pad"
+        ].unsqueeze(-1)
+
+        l1_loss = elementwise_l1.mean()
+
+        l1_per_sample = elementwise_l1.mean(dim=(1, 2))
 
-        loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
@@ -168,9 +171,17 @@ class ACTPolicy(PreTrainedPolicy):
             mean_kld = (
                 (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
             )
-            loss_dict["kld_loss"] = mean_kld.item()
+            loss_dict = {
+                "l1_loss": l1_loss.item(),
+                "kld_loss": mean_kld.item(),
+                "per_sample_l1": l1_per_sample,
+            }
             loss = l1_loss + mean_kld * self.config.kl_weight
         else:
+            loss_dict = {
+                "l1_loss": l1_loss.item(),
+                "per_sample_l1": l1_per_sample,
+            }
             loss = l1_loss
 
         return loss, loss_dict
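For reference, the `mean_kld` term kept in this hunk is the closed-form KL divergence between the encoder's diagonal-Gaussian latent and a standard normal (with `log_sigma_x2_hat` = log σ²), summed over latent dimensions and then averaged over the batch:

```latex
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\middle\|\, \mathcal{N}(0, I)\right)
= -\frac{1}{2} \sum_{d} \left(1 + \log \sigma_d^2 - \mu_d^2 - \sigma_d^2\right)
```

The new `per_sample_l1` entry, the mean of `elementwise_l1` over time and action dimensions, is the difficulty signal consumed by the `PrioritizedSampler` hunk in `train.py` further below; note it stays a tensor while the other dict entries are Python floats.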
@@ -114,7 +114,7 @@ def save_images_from_cameras(
     camera = IntelRealSenseCamera(config)
     camera.connect()
     print(
-        f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
+        f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
     )
     cameras.append(camera)

@@ -224,9 +224,20 @@ class IntelRealSenseCamera:
             self.serial_number = self.find_serial_number_from_name(config.name)
         else:
             self.serial_number = config.serial_number
 
+        # Store the raw (capture) resolution from the config.
+        self.capture_width = config.width
+        self.capture_height = config.height
+
+        # If rotated by ±90, swap width and height.
+        if config.rotation in [-90, 90]:
+            self.width = config.height
+            self.height = config.width
+        else:
+            self.width = config.width
+            self.height = config.height
+
         self.fps = config.fps
-        self.width = config.width
-        self.height = config.height
         self.channels = config.channels
         self.color_mode = config.color_mode
         self.use_depth = config.use_depth
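A small illustrative sketch (not part of the diff) of the resolution convention these camera changes introduce: `capture_*` is the raw sensor/config resolution, while `width`/`height` describe the frames delivered to callers after an optional ±90° rotation.

```python
def post_rotation_dims(capture_width: int, capture_height: int, rotation: int | None) -> tuple[int, int]:
    """Return (width, height) of delivered frames for a rotation given in degrees."""
    if rotation in (-90, 90):
        # A quarter-turn swaps the two axes; 180° and no rotation leave them alone.
        return capture_height, capture_width
    return capture_width, capture_height


assert post_rotation_dims(640, 480, 90) == (480, 640)
assert post_rotation_dims(640, 480, None) == (640, 480)
```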
@@ -246,7 +257,6 @@ class IntelRealSenseCamera:
         else:
             import cv2
 
-        # TODO(alibets): Do we keep original width/height or do we define them after rotation?
         self.rotation = None
         if config.rotation == -90:
             self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE

@@ -284,15 +294,19 @@ class IntelRealSenseCamera:
         config = rs.config()
         config.enable_device(str(self.serial_number))
 
-        if self.fps and self.width and self.height:
+        if self.fps and self.capture_width and self.capture_height:
             # TODO(rcadene): can we set rgb8 directly?
-            config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
+            config.enable_stream(
+                rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
+            )
         else:
             config.enable_stream(rs.stream.color)
 
         if self.use_depth:
-            if self.fps and self.width and self.height:
-                config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
+            if self.fps and self.capture_width and self.capture_height:
+                config.enable_stream(
+                    rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
+                )
             else:
                 config.enable_stream(rs.stream.depth)

@@ -330,18 +344,18 @@ class IntelRealSenseCamera:
             raise OSError(
                 f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
             )
-        if self.width is not None and self.width != actual_width:
+        if self.capture_width is not None and self.capture_width != actual_width:
             raise OSError(
-                f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
+                f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
             )
-        if self.height is not None and self.height != actual_height:
+        if self.capture_height is not None and self.capture_height != actual_height:
             raise OSError(
-                f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
+                f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
             )
 
         self.fps = round(actual_fps)
-        self.width = round(actual_width)
-        self.height = round(actual_height)
+        self.capture_width = round(actual_width)
+        self.capture_height = round(actual_height)
 
         self.is_connected = True

@@ -387,7 +401,7 @@ class IntelRealSenseCamera:
             color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
 
         h, w, _ = color_image.shape
-        if h != self.height or w != self.width:
+        if h != self.capture_height or w != self.capture_width:
             raise OSError(
                 f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
             )

@@ -409,7 +423,7 @@ class IntelRealSenseCamera:
             depth_map = np.asanyarray(depth_frame.get_data())
 
         h, w = depth_map.shape
-        if h != self.height or w != self.width:
+        if h != self.capture_height or w != self.capture_width:
             raise OSError(
                 f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
             )
@@ -144,8 +144,8 @@ def save_images_from_cameras(
     camera = OpenCVCamera(config)
     camera.connect()
     print(
-        f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
-        f"height={camera.height}, color_mode={camera.color_mode})"
+        f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
+        f"height={camera.capture_height}, color_mode={camera.color_mode})"
     )
     cameras.append(camera)

@@ -244,9 +244,19 @@ class OpenCVCamera:
         else:
             raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
 
+        # Store the raw (capture) resolution from the config.
+        self.capture_width = config.width
+        self.capture_height = config.height
+
+        # If rotated by ±90, swap width and height.
+        if config.rotation in [-90, 90]:
+            self.width = config.height
+            self.height = config.width
+        else:
+            self.width = config.width
+            self.height = config.height
+
         self.fps = config.fps
-        self.width = config.width
-        self.height = config.height
         self.channels = config.channels
         self.color_mode = config.color_mode
         self.mock = config.mock

@@ -263,7 +273,6 @@ class OpenCVCamera:
         else:
             import cv2
 
-        # TODO(aliberts): Do we keep original width/height or do we define them after rotation?
         self.rotation = None
         if config.rotation == -90:
             self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE

@@ -325,10 +334,10 @@ class OpenCVCamera:
 
         if self.fps is not None:
             self.camera.set(cv2.CAP_PROP_FPS, self.fps)
-        if self.width is not None:
-            self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
-        if self.height is not None:
-            self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
+        if self.capture_width is not None:
+            self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
+        if self.capture_height is not None:
+            self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
 
         actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
         actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)

@@ -340,19 +349,22 @@ class OpenCVCamera:
             raise OSError(
                 f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
             )
-        if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
+        if self.capture_width is not None and not math.isclose(
+            self.capture_width, actual_width, rel_tol=1e-3
+        ):
             raise OSError(
-                f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
+                f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
             )
-        if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
+        if self.capture_height is not None and not math.isclose(
+            self.capture_height, actual_height, rel_tol=1e-3
+        ):
             raise OSError(
-                f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
+                f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
             )
 
         self.fps = round(actual_fps)
-        self.width = round(actual_width)
-        self.height = round(actual_height)
-
+        self.capture_width = round(actual_width)
+        self.capture_height = round(actual_height)
         self.is_connected = True
 
     def read(self, temporary_color_mode: str | None = None) -> np.ndarray:

@@ -393,7 +405,7 @@ class OpenCVCamera:
             color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
 
         h, w, _ = color_image.shape
-        if h != self.height or w != self.width:
+        if h != self.capture_height or w != self.capture_width:
             raise OSError(
                 f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
             )
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
+import pkgutil
 import sys
 from argparse import ArgumentError
 from functools import wraps

@@ -23,6 +25,7 @@ import draccus
 from lerobot.common.utils.utils import has_method
 
 PATH_KEY = "path"
+PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
 draccus.set_config_type("json")

@@ -58,6 +61,86 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
     return None
 
 
+def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
+    """Parse plugin-related arguments from command-line arguments.
+
+    This function extracts arguments from command-line arguments that match a specified suffix pattern.
+    It processes arguments in the format '--key=value' and returns them as a dictionary.
+
+    Args:
+        plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
+        args (Sequence[str]): A sequence of command-line arguments to parse.
+
+    Returns:
+        dict: A dictionary containing the parsed plugin arguments where:
+            - Keys are the argument names (with '--' prefix removed if present)
+            - Values are the corresponding argument values
+
+    Example:
+        >>> args = ['--env.discover_packages_path=my_package',
+        ...         '--other_arg=value']
+        >>> parse_plugin_args('discover_packages_path', args)
+        {'env.discover_packages_path': 'my_package'}
+    """
+    plugin_args = {}
+    for arg in args:
+        if "=" in arg and plugin_arg_suffix in arg:
+            key, value = arg.split("=", 1)
+            # Remove leading '--' if present
+            if key.startswith("--"):
+                key = key[2:]
+            plugin_args[key] = value
+    return plugin_args
+
+
+class PluginLoadError(Exception):
+    """Raised when a plugin fails to load."""
+
+
+def load_plugin(plugin_path: str) -> None:
+    """Load and initialize a plugin from a given Python package path.
+
+    This function attempts to load a plugin by importing its package and any submodules.
+    Plugin registration is expected to happen during package initialization, i.e. when
+    the package is imported the gym environment should be registered and the config classes
+    registered with their parents using the `register_subclass` decorator.
+
+    Args:
+        plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin")
+
+    Raises:
+        PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid.
+
+    Examples:
+        >>> load_plugin("external_plugin.core")  # Loads plugin from external package
+
+    Notes:
+        - The plugin package should handle its own registration during import
+        - All submodules in the plugin package will be imported
+        - Implementation follows the plugin discovery pattern from Python packaging guidelines
+
+    See Also:
+        https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
+    """
+    try:
+        package_module = importlib.import_module(plugin_path, __package__)
+    except (ImportError, ModuleNotFoundError) as e:
+        raise PluginLoadError(
+            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
+        ) from e
+
+    def iter_namespace(ns_pkg):
+        return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
+
+    try:
+        for _finder, pkg_name, _ispkg in iter_namespace(package_module):
+            importlib.import_module(pkg_name)
+    except ImportError as e:
+        raise PluginLoadError(
+            f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
+        ) from e
+
+
 def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
     return parse_arg(f"{field_name}.{PATH_KEY}", args)
@@ -105,10 +188,13 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
 
 def wrap(config_path: Path | None = None):
     """
-    HACK: Similar to draccus.wrap but does two additional things:
+    HACK: Similar to draccus.wrap but does three additional things:
     - Will remove '.path' arguments from CLI in order to process them later on.
     - If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
       initialize it from there to allow to fetch configs from the hub directly
+    - Will load plugins specified in the CLI arguments. These plugins will typically register
+      their own subclasses of config classes, so that draccus can find the right class to instantiate
+      from the CLI '.type' arguments
     """
 
     def wrapper_outer(fn):

@@ -121,6 +207,14 @@ def wrap(config_path: Path | None = None):
                     args = args[1:]
                 else:
                     cli_args = sys.argv[1:]
+                plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
+                for plugin_cli_arg, plugin_path in plugin_args.items():
+                    try:
+                        load_plugin(plugin_path)
+                    except PluginLoadError as e:
+                        # add the relevant CLI arg to the error message
+                        raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
+                    cli_args = filter_arg(plugin_cli_arg, cli_args)
                 config_path_cli = parse_arg("config_path", cli_args)
                 if has_method(argtype, "__get_path_fields__"):
                     path_fields = argtype.__get_path_fields__()
@@ -25,7 +25,7 @@ from torch.amp import GradScaler
 from torch.optim import Optimizer
 
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.sampler import EpisodeAwareSampler
+from lerobot.common.datasets.sampler import EpisodeAwareSampler, PrioritizedSampler
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.optim.factory import make_optimizer_and_scheduler

@@ -70,6 +70,7 @@ def update_policy(
     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+
     grad_scaler.scale(loss).backward()
 
     # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.

@@ -126,6 +127,7 @@ def train(cfg: TrainPipelineConfig):
 
     logging.info("Creating dataset")
     dataset = make_dataset(cfg)
+    data_len = len(dataset)
 
     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,

@@ -174,6 +176,15 @@ def train(cfg: TrainPipelineConfig):
         shuffle = True
         sampler = None
 
+    # TODO(pepijn): If experiment works integrate this
+    shuffle = False
+    sampler = PrioritizedSampler(
+        data_len=data_len,
+        alpha=0.6,
+        eps=1e-6,
+        num_samples_per_epoch=data_len,
+    )
+
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=cfg.num_workers,

@@ -220,6 +231,12 @@ def train(cfg: TrainPipelineConfig):
             use_amp=cfg.policy.use_amp,
         )
 
+        # Update sampler
+        if "indices" in batch and "per_sample_l1" in output_dict:
+            idxs = batch["indices"].cpu().tolist()
+            diffs = output_dict["per_sample_l1"].detach().cpu().tolist()
+            sampler.update_priorities(idxs, diffs)
+
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
         # increment `step` here.
         step += 1

@@ -56,7 +56,6 @@ dependencies = [
     "gymnasium==0.29.1",  # TODO(rcadene, aliberts): Make gym 1.0.0 work
     "h5py>=3.10.0",
     "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
-    "hydra-core>=1.3.2",
     "imageio[ffmpeg]>=2.34.0",
     "jsonlines>=4.0.0",
     "numba>=0.59.0",
tests/configs/test_plugin_loading.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Generator
+
+import pytest
+
+from lerobot.common.envs.configs import EnvConfig
+from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
+
+
+def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
+    """Creates a dummy plugin module that implements its own EnvConfig subclass."""
+    return f"""
+from dataclasses import dataclass
+from lerobot.common.envs.configs import {base_class}
+
+@{base_class}.register_subclass("{plugin_name}")
+@dataclass
+class TestPluginConfig:
+    value: int = 42
+"""
+
+
+@pytest.fixture
+def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]:
+    """Creates a temporary plugin package structure."""
+    plugin_pkg = tmp_path / "test_plugin"
+    plugin_pkg.mkdir()
+    (plugin_pkg / "__init__.py").touch()
+
+    with open(plugin_pkg / "my_plugin.py", "w") as f:
+        f.write(create_plugin_code())
+
+    # Add tmp_path to Python path so we can import from it
+    sys.path.insert(0, str(tmp_path))
+    yield plugin_pkg
+    sys.path.pop(0)
+
+
+def test_parse_plugin_args():
+    cli_args = [
+        "--env.type=test",
+        "--model.discover_packages_path=some.package",
+        "--env.discover_packages_path=other.package",
+    ]
+    plugin_args = parse_plugin_args("discover_packages_path", cli_args)
+    assert plugin_args == {
+        "model.discover_packages_path": "some.package",
+        "env.discover_packages_path": "other.package",
+    }
+
+
+def test_load_plugin_success(plugin_dir: Path):
+    # Import should work and register the plugin with the real EnvConfig
+    load_plugin("test_plugin")
+
+    assert "test_env" in EnvConfig.get_known_choices()
+    plugin_cls = EnvConfig.get_choice_class("test_env")
+    plugin_instance = plugin_cls()
+    assert plugin_instance.value == 42
+
+
+def test_load_plugin_failure():
+    with pytest.raises(PluginLoadError) as exc_info:
+        load_plugin("nonexistent_plugin")
+    assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value)
+
+
+def test_wrap_with_plugin(plugin_dir: Path):
+    @dataclass
+    class Config:
+        env: EnvConfig
+
+    @wrap()
+    def dummy_func(cfg: Config):
+        return cfg
+
+    # Test loading plugin via CLI args
+    sys.argv = [
+        "dummy_script.py",
+        "--env.discover_packages_path=test_plugin",
+        "--env.type=test_env",
+    ]
+
+    cfg = dummy_func()
+    assert isinstance(cfg, Config)
+    assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env"))
+    assert cfg.env.value == 42
@@ -85,8 +85,8 @@ def test_camera(request, camera_type, mock):
     camera.connect()
     assert camera.is_connected
    assert camera.fps is not None
-    assert camera.width is not None
-    assert camera.height is not None
+    assert camera.capture_width is not None
+    assert camera.capture_height is not None
 
     # Test connecting twice raises an error
     with pytest.raises(RobotDeviceAlreadyConnectedError):
@@ -204,3 +204,49 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
 
     # Small `record_time_s` to speedup unit tests
     save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)
+
+
+@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
+@require_camera
+def test_camera_rotation(request, camera_type, mock):
+    config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
+
+    # No rotation.
+    camera = make_camera(**config_kwargs, rotation=None)
+    camera.connect()
+    assert camera.capture_width == 640
+    assert camera.capture_height == 480
+    assert camera.width == 640
+    assert camera.height == 480
+    no_rot_img = camera.read()
+    h, w, c = no_rot_img.shape
+    assert h == 480 and w == 640 and c == 3
+    camera.disconnect()
+
+    # Rotation = 90 (clockwise).
+    camera = make_camera(**config_kwargs, rotation=90)
+    camera.connect()
+    # With a 90° rotation, we expect the metadata dimensions to be swapped.
+    assert camera.capture_width == 640
+    assert camera.capture_height == 480
+    assert camera.width == 480
+    assert camera.height == 640
+    import cv2
+
+    assert camera.rotation == cv2.ROTATE_90_CLOCKWISE
+    rot_img = camera.read()
+    h, w, c = rot_img.shape
+    assert h == 640 and w == 480 and c == 3
+    camera.disconnect()
+
+    # Rotation = 180 (dimensions are unchanged).
+    camera = make_camera(**config_kwargs, rotation=180)
+    camera.connect()
+    assert camera.capture_width == 640
+    assert camera.capture_height == 480
+    assert camera.width == 640
+    assert camera.height == 480
+    no_rot_img = camera.read()
+    h, w, c = no_rot_img.shape
+    assert h == 480 and w == 640 and c == 3
+    camera.disconnect()
@@ -368,7 +368,7 @@ def test_normalize(insert_temporal_dim):
     # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
     # to test with `policy.use_mpc=false`.
     ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
-    ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
+    # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
     # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
     # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
     # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.