Compare commits
24 Commits
user/miche
...
user/alibe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1386fd79e | ||
|
|
b47620cd59 | ||
|
|
a32d988536 | ||
|
|
9571a713df | ||
|
|
b418409b24 | ||
|
|
0a6b3992ee | ||
|
|
e6d19116c4 | ||
|
|
92ea7fc0fb | ||
|
|
46cd157c55 | ||
|
|
52028f5201 | ||
|
|
f5b1ef0045 | ||
|
|
81a4deadc3 | ||
|
|
fef83ce349 | ||
|
|
eb3986e131 | ||
|
|
d45226ad06 | ||
|
|
fe43f93553 | ||
|
|
40e0a311b5 | ||
|
|
13677cb720 | ||
|
|
247d493d06 | ||
|
|
2f00475fc6 | ||
|
|
4687296d93 | ||
|
|
5c2f8ccd14 | ||
|
|
d25e3bd989 | ||
|
|
bfd26eef5a |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,7 +29,6 @@ outputs
|
||||
|
||||
# VS Code
|
||||
.vscode
|
||||
.devcontainer
|
||||
|
||||
# HPC
|
||||
nautilus/*.yaml
|
||||
|
||||
@@ -46,6 +46,7 @@ repos:
|
||||
rev: v3.20.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.11
|
||||
hooks:
|
||||
|
||||
13
README.md
13
README.md
@@ -408,19 +408,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
- [HIL-SERL](https://hil-serl.github.io/)
|
||||
```bibtex
|
||||
@Article{luo2024hilserl,
|
||||
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
|
||||
author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine},
|
||||
year={2024},
|
||||
eprint={2410.21845},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}
|
||||
```
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -22,7 +22,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
|
||||
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Execute in bash shell rather than python
|
||||
|
||||
@@ -9,8 +9,6 @@
|
||||
title: Getting Started with Real-World Robots
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: hilserl
|
||||
title: Getting Started with Reinforcement Learning
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: so101
|
||||
|
||||
@@ -36,16 +36,18 @@ If you haven't yet set up and calibrated your robot and teleop device, please do
|
||||
|
||||
In this example, we’ll demonstrate how to teleoperate the SO101 robot. For each command, we also provide a corresponding API example.
|
||||
|
||||
Note that the `id` associated with a robot is used to store the calibration file. It's important to use the same `id` when teleoperating, recording, and evaluating when using the same setup.
|
||||
|
||||
<hfoptions id="teleoperate_so101">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_red_robot_arm \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_blue_leader_arm
|
||||
--teleop.id=my_awesome_leader_arm
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
@@ -93,11 +95,11 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
python -m lerobot.teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_koch_robot \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_koch_teleop \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -157,13 +159,13 @@ Now you can record a dataset. To record 2 episodes and upload your dataset to th
|
||||
python -m lerobot.record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \
|
||||
--robot.id=my_red_robot_arm \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_blue_leader_arm \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.repo_id=${HF_USER}/record-test \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the black cube"
|
||||
```
|
||||
@@ -227,18 +229,6 @@ If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you c
|
||||
echo ${HF_USER}/so101_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can visualize it locally with (via a window in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so101_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
This will launch a local web server that looks like this:
|
||||
<div style="text-align:center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/visualize_dataset_html.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="100%"></img>
|
||||
</div>
|
||||
|
||||
## Replay an episode
|
||||
|
||||
A useful feature is the `replay` function, which allows you to replay any episode that you've recorded or episodes from any dataset out there. This function helps you test the repeatability of your robot's actions and assess transferability across robots of the same model.
|
||||
@@ -248,9 +238,9 @@ You can replay the first episode on your robot with:
|
||||
python -m lerobot.replay \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.episode=2
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--dataset.repo_id=${HF_USER}/record-test \
|
||||
--dataset.episode=0 # choose the episode you want to replay
|
||||
```
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
@@ -306,14 +296,14 @@ python -m lerobot.record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=blue_follower_arm \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM0 \
|
||||
--teleop.id=red_leader_arm \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=$HF_USER/eval_lego_${EPOCHREALTIME/[^0-9]/} \
|
||||
--dataset.repo_id=$HF_USER/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--policy.path=${HF_USER}/act_johns_arm
|
||||
--policy.path=${HF_USER}/my_policy
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
|
||||
@@ -1,512 +0,0 @@
|
||||
# HilSerl Real Robot Training Workflow Guide
|
||||
|
||||
Human-in-the-Loop Sample-Efficient Reinforcement Learning (HIL-SERL) with LeRobot workflow for taking a policy from “zero” to real-world robot mastery in just a couple of hours.
|
||||
It combines three ingredients:
|
||||
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
|
||||
2. **On-robot actor / learner loop with human interventions:** a distributed SAC/RLPD learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
|
||||
3. **Safety & efficiency tools:** joint/EE bounds, impedance control, crop-ROI preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
|
||||
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/hilserl-main-figure.png" alt="HIL-SERL workflow" title="HIL-SERL workflow" width="100%"></img>
|
||||
</p>
|
||||
|
||||
<p align="center"><i>HIL-SERL workflow, Luo et al. 2024</i></p>
|
||||
|
||||
This guide provides step-by-step instructions for training a robot policy using LeRobot's HilSerl implementation to train on a real robot.
|
||||
|
||||
|
||||
# 1. Real Robot Training Workflow
|
||||
|
||||
## 1.1 Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/common/envs/configs.py`. Which is defined as:
|
||||
|
||||
```python
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
robot: Optional[RobotConfig] = None # Main robot agent (defined in `lerobot/common/robots`)
|
||||
teleop: Optional[TeleoperatorConfig] = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/common/teleoperators`)
|
||||
wrapper: Optional[EnvTransformConfig] = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
|
||||
fps: int = 10 # Control frequency
|
||||
name: str = "real_robot" # Environment name
|
||||
mode: str = None # "record", "replay", or None (for training)
|
||||
repo_id: Optional[str] = None # LeRobot dataset repository ID
|
||||
dataset_root: Optional[str] = None # Local dataset root (optional)
|
||||
task: str = "" # Task identifier
|
||||
num_episodes: int = 10 # Number of episodes for recording
|
||||
episode: int = 0 # episode index for replay
|
||||
device: str = "cuda" # Compute device
|
||||
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
|
||||
pretrained_policy_name_or_path: Optional[str] = None # For policy loading
|
||||
reward_classifier_pretrained_path: Optional[str] = None # For reward model
|
||||
```
|
||||
|
||||
|
||||
## 1.2 Finding Robot Workspace Bounds
|
||||
|
||||
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
|
||||
|
||||
This helps simplifying the problem of learning on the real robot by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration.
|
||||
|
||||
### 1.2.1 Using find_joint_limits.py
|
||||
|
||||
This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training.
|
||||
Bounding the action space will reduce the redundant exploration of the agent and guarantees safety.
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.find_joint_limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
```
|
||||
|
||||
### 1.2.2 Workflow
|
||||
|
||||
1. Run the script and move the robot through the space that solves the task
|
||||
2. The script will record the minimum and maximum end-effector positions and the joint angles and prints them to the console, for example:
|
||||
```
|
||||
Max ee position [0.24170487 0.201285 0.10273342]
|
||||
Min ee position [0.16631757 -0.08237468 0.03364977]
|
||||
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
||||
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
|
||||
```
|
||||
3. Use these values in the configuration of you teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
|
||||
|
||||
### 1.2.3 Example Configuration
|
||||
|
||||
```json
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.20, 0.10],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
}
|
||||
```
|
||||
|
||||
## 1.3 Collecting Demonstrations
|
||||
|
||||
With the bounds defined, you can safely collect demonstrations for training. Training RL with off-policy algorithm allows us to use offline datasets collected in order to improve the efficiency of the learning process.
|
||||
|
||||
### 1.3.1 Setting Up Record Mode
|
||||
|
||||
Create a configuration file for recording demonstrations (or edit an existing one like `env_config_so100.json`):
|
||||
|
||||
1. Set `mode` to `"record"`
|
||||
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
|
||||
3. Set `num_episodes` to the number of demonstrations you want to collect
|
||||
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
|
||||
5. Configure `robot`, `cameras`, and other hardware settings
|
||||
|
||||
Example configuration section:
|
||||
```json
|
||||
"mode": "record",
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
```
|
||||
|
||||
### 1.3.2 Gamepad Controls
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/gamepad_guide.jpg?raw=true" alt="Figure shows the control mappings on a Logitech gamepad." title="Gamepad Control Mapping" width="100%"></img>
|
||||
</p>
|
||||
<p align="center"><i>Gamepad button mapping for robot control and episode management</i></p>
|
||||
|
||||
|
||||
### 1.3.3 Recording Demonstrations
|
||||
|
||||
Start the recording process:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config_so100.json
|
||||
```
|
||||
|
||||
During recording:
|
||||
1. The robot will reset to the initial position defined in the configuration file `fixed_reset_position`
|
||||
2. Use the gamepad to control the robot by setting `"control_mode"="gamepad"` in the configuration file
|
||||
3. Complete the task successfully
|
||||
4. The episode ends with a reward of 1 when you press the "success" button
|
||||
5. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
|
||||
6. You can rerecord an episode by pressing the "rerecord" button
|
||||
7. The process automatically continues to the next episode
|
||||
8. After recording all episodes, the dataset is pushed to the Hugging Face Hub (optional) and saved locally
|
||||
|
||||
|
||||
|
||||
## 1.4 Processing the Dataset
|
||||
|
||||
After collecting demonstrations, process them to determine optimal camera crops.
|
||||
Reinforcement learning is sensitive to background distractions, so it is important to crop the images to the relevant workspace area.
|
||||
Note: If you already know the crop parameters, you can skip this step and just set the `crop_params_dict` in the configuration file during recording.
|
||||
|
||||
### 1.4.1 Determining Crop Parameters
|
||||
|
||||
Use the `crop_dataset_roi.py` script to interactively select regions of interest in your camera images:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/crop_dataset_roi.py --repo-id username/pick_lift_cube
|
||||
```
|
||||
|
||||
1. For each camera view, the script will display the first frame
|
||||
2. Draw a rectangle around the relevant workspace area
|
||||
3. Press 'c' to confirm the selection
|
||||
4. Repeat for all camera views
|
||||
5. The script outputs cropping parameters and creates a new cropped dataset
|
||||
|
||||
Example output:
|
||||
```
|
||||
Selected Rectangular Regions of Interest (top, left, height, width):
|
||||
observation.images.side: [180, 207, 180, 200]
|
||||
observation.images.front: [180, 250, 120, 150]
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/crop_dataset.gif" width="600"/>
|
||||
</p>
|
||||
|
||||
<p align="center"><i>Interactive cropping tool for selecting regions of interest</i></p>
|
||||
|
||||
|
||||
### 1.4.2 Updating Configuration
|
||||
|
||||
Add these crop parameters to your training configuration:
|
||||
|
||||
```json
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
```
|
||||
|
||||
## 1.5 Training with Actor-Learner
|
||||
|
||||
The LeRobot system uses a distributed actor-learner architecture for training. You will need to start two processes: a learner and an actor.
|
||||
|
||||
### 1.5.1 Configuration Setup
|
||||
|
||||
Create a training configuration file (See example `train_config_hilserl_so100.json`). The training config is based on the main `TrainPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
|
||||
1. Set `mode` to `null` (for training mode)
|
||||
2. Configure the policy settings (`type`, `device`, etc.)
|
||||
3. Set `dataset` to your cropped dataset
|
||||
4. Configure environment settings with crop parameters
|
||||
5. Check the other parameters related to SAC.
|
||||
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
### 1.5.2 Starting the Learner
|
||||
|
||||
First, start the learner server process:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
The learner:
|
||||
- Initializes the policy network
|
||||
- Prepares replay buffers
|
||||
- Opens a gRPC server to communicate with actors
|
||||
- Processes transitions and updates the policy
|
||||
|
||||
### 1.5.3 Starting the Actor
|
||||
|
||||
In a separate terminal, start the actor process with the same configuration:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
The actor:
|
||||
- Connects to the learner via gRPC
|
||||
- Initializes the environment
|
||||
- Execute rollouts of the policy to collect experience
|
||||
- Sends transitions to the learner
|
||||
- Receives updated policy parameters
|
||||
|
||||
### 1.5.4 Training Flow
|
||||
|
||||
The training proceeds automatically:
|
||||
|
||||
1. The actor executes the policy in the environment
|
||||
2. Transitions are collected and sent to the learner
|
||||
3. The learner updates the policy based on these transitions
|
||||
4. Updated policy parameters are sent back to the actor
|
||||
5. The process continues until the specified step limit is reached
|
||||
|
||||
### 1.5.5 Human in the Loop
|
||||
|
||||
- The key to learning efficiently is to have a human interventions to provide corrective feedback and completing the task to aide the policy learning and exploration.
|
||||
- To perform human interventions, you can press the upper right trigger button on the gamepad. This will pause the policy actions and allow you to take over.
|
||||
- A successful experiment is one where the human has to intervene at the start but then reduces the amount of interventions as the policy improves. You can monitor the intervention rate in the `wandb` dashboard.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/hil_effect.png?raw=true" alt="Figure shows the control mappings on a Logitech gamepad." title="Gamepad Control Mapping" width="100%"></img>
|
||||
</p>
|
||||
|
||||
<p align="center"><i>Example showing how human interventions help guide policy learning over time</i></p>
|
||||
|
||||
- The figure shows the plot of the episodic reward over interaction step. The figure shows the effect of human interventions on the policy learning.
|
||||
- The orange curve is an experiment without any human interventions. While the pink and blue curves are experiments with human interventions.
|
||||
- We can observe that the number of steps where the policy starts achieving the maximum reward is cut by a quarter when human interventions are present.
|
||||
|
||||
#### Guide to Human Interventions
|
||||
The strategy to follow is to intervene heavily at the start of training and then reduce the amount of interventions as the training progresses. Some tips and hints:
|
||||
- Interevene for almost the length of the entire episode at the first few episodes.
|
||||
- When the policy is less chaotic, gradually reduce the intervention time during one episode and let the policy explore for a longer time.
|
||||
- Once the policy start guiding the robot towards achieving the task, even if its not perfect, you can limit your interventions to simple quick actions like a grasping command, or grasp and lift command.
|
||||
|
||||
## 1.6 Monitoring and Debugging
|
||||
|
||||
If you have `wandb.enable` set to `true` in your configuration, you can monitor training progress in real-time through the [Weights & Biases](https://wandb.ai/site/) dashboard.
|
||||
|
||||
# 2. Training a Reward Classifier with LeRobot
|
||||
|
||||
This guide explains how to train a reward classifier for human-in-the-loop reinforcement learning implementation of LeRobot. Reward classifiers learn to predict the reward value given a state which can be used in an RL setup to train a policy.
|
||||
|
||||
|
||||
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
|
||||
## 2.1 Collecting a Dataset
|
||||
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
|
||||
To collect a dataset, you need to modeify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
### 2.1.1 Key Parameters for Data Collection:
|
||||
|
||||
- **mode**: set it to "record" to collect a dataset
|
||||
- **repo_id**: "hf_username/dataset_name", name of the dataset and repo on the hub
|
||||
- **num_episodes**: Number of episodes to record
|
||||
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
|
||||
- **fps**: Number of frames per second to record
|
||||
- **push_to_hub**: Whether to push the dataset to the hub
|
||||
|
||||
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
|
||||
|
||||
Example configuration section for data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"mode": "record",
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"num_episodes": 20,
|
||||
"push_to_hub": true,
|
||||
"fps": 10,
|
||||
"number_of_steps_after_success": 15
|
||||
}
|
||||
```
|
||||
|
||||
## 2.2 Reward Classifier Configuration
|
||||
|
||||
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
|
||||
|
||||
- **model_name**: Base model architecture (e.g., we mainly use "helper2424/resnet10")
|
||||
- **model_type**: "cnn" or "transformer"
|
||||
- **num_cameras**: Number of camera inputs
|
||||
- **num_classes**: Number of output classes (typically 2 for binary success/failure)
|
||||
- **hidden_dim**: Size of hidden representation
|
||||
- **dropout_rate**: Regularization parameter
|
||||
- **learning_rate**: Learning rate for optimizer
|
||||
|
||||
Example configuration from `reward_classifier_train_config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"policy": {
|
||||
"type": "reward_classifier",
|
||||
"model_name": "helper2424/resnet10",
|
||||
"model_type": "cnn",
|
||||
"num_cameras": 2,
|
||||
"num_classes": 2,
|
||||
"hidden_dim": 256,
|
||||
"dropout_rate": 0.1,
|
||||
"learning_rate": 1e-4,
|
||||
"device": "cuda",
|
||||
"use_amp": true,
|
||||
"input_features": {
|
||||
"observation.images.front": {
|
||||
"type": "VISUAL",
|
||||
"shape": [3, 128, 128]
|
||||
},
|
||||
"observation.images.side": {
|
||||
"type": "VISUAL",
|
||||
"shape": [3, 128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 2.3 Training the Classifier
|
||||
|
||||
To train the classifier, use the `train.py` script with your configuration:
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
## 2.4 Deploying and Testing the Model
|
||||
|
||||
To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use your model:
|
||||
|
||||
```python
|
||||
env_config = HILSerlRobotEnvConfig(
|
||||
reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
|
||||
# Other environment parameters
|
||||
)
|
||||
```
|
||||
or set the argument in the json config file.
|
||||
|
||||
```json
|
||||
{
|
||||
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
|
||||
}
|
||||
```
|
||||
|
||||
Run gym_manipulator.py to test the model.
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json
|
||||
```
|
||||
|
||||
The reward classifier will automatically provide rewards based on the visual input from the robot's cameras.
|
||||
|
||||
## 2.5 Example Workflow
|
||||
|
||||
1. **Create the configuration files**:
|
||||
Create the necessary json configuration files for the reward classifier and the environment. Check the `json_examples` directory for examples.
|
||||
|
||||
2. **Collect a dataset**:
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json
|
||||
```
|
||||
|
||||
3. **Train the classifier**:
|
||||
```bash
|
||||
python lerobot/scripts/train.py --config_path lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
4. **Test the classifier**:
|
||||
```bash
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json
|
||||
```
|
||||
# 3. Using gym_hil Simulation Environments with LeRobot
|
||||
|
||||
This guide explains how to use the `gym_hil` simulation environments as an alternative to real robots when working with the LeRobot framework for Human-In-the-Loop (HIL) reinforcement learning.
|
||||
|
||||
`gym_hil` is a package that provides Gymnasium-compatible simulation environments specifically designed for Human-In-the-Loop reinforcement learning. These environments allow you to:
|
||||
|
||||
- Train policies in simulation to test the RL stack before training on real robots
|
||||
|
||||
- Collect demonstrations in sim using external devices like gamepads or keyboards
|
||||
- Perform human interventions during policy learning
|
||||
|
||||
Currently, the main environment is a Franka Panda robot simulation based on MuJoCo, with tasks like picking up a cube.
|
||||
|
||||
## 3.1 Installation
|
||||
|
||||
First, install the `gym_hil` package within the LeRobot environment:
|
||||
|
||||
```bash
|
||||
pip install gym_hil
|
||||
|
||||
# Or in LeRobot
|
||||
cd lerobot
|
||||
pip install -e .[hilserl]
|
||||
```
|
||||
|
||||
## 3.2 Configuration
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided in `gym_hil_env.json`. Key configuration sections include:
|
||||
|
||||
### 3.2.1 Environment Type and Task
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hil",
|
||||
"name": "franka_sim",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Available tasks:
|
||||
- `PandaPickCubeBase-v0`: Basic environment
|
||||
- `PandaPickCubeGamepad-v0`: With gamepad control
|
||||
- `PandaPickCubeKeyboard-v0`: With keyboard control
|
||||
|
||||
### 3.2.2 Gym Wrappers Configuration
|
||||
|
||||
```json
|
||||
"wrapper": {
|
||||
"gripper_penalty": -0.02,
|
||||
"control_time_s": 15.0,
|
||||
"use_gripper": true,
|
||||
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
},
|
||||
"control_mode": "gamepad"
|
||||
}
|
||||
```
|
||||
|
||||
Important parameters:
|
||||
- `gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `use_gripper`: Whether to enable gripper control
|
||||
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `control_mode`: Set to "gamepad" to use a gamepad controller
|
||||
|
||||
## 3.3 Running with HIL RL of LeRobot
|
||||
|
||||
### 3.3.1 Basic Usage
|
||||
|
||||
To run the environment, set mode to null:
|
||||
|
||||
```python
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
|
||||
### 3.3.2 Recording a Dataset
|
||||
|
||||
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
|
||||
|
||||
```python
|
||||
python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
|
||||
### 3.3.3 Training a Policy
|
||||
|
||||
To train a policy, checkout the example json in `train_gym_hil_env.json` and run the actor and learner servers:
|
||||
|
||||
```python
|
||||
python lerobot/scripts/rl/actor.py --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
|
||||
In a different terminal, run the learner server:
|
||||
|
||||
```python
|
||||
python lerobot/scripts/rl/learner.py --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
|
||||
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
|
||||
|
||||
Paper citation:
|
||||
|
||||
```
|
||||
@article{luo2024precise,
|
||||
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
|
||||
author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2410.21845},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
@@ -1,998 +0,0 @@
|
||||
# Getting Started with Real-World Robots
|
||||
|
||||
This tutorial will guide you through the process of setting up and training a neural network to autonomously control a real robot.
|
||||
|
||||
**What You'll Learn:**
|
||||
1. How to order and assemble your robot.
|
||||
2. How to connect, configure, and calibrate your robot.
|
||||
3. How to record and visualize your dataset.
|
||||
4. How to train a policy using your data and prepare it for evaluation.
|
||||
5. How to evaluate your policy and visualize the results.
|
||||
|
||||
By following these steps, you'll be able to replicate tasks like picking up a Lego block and placing it in a bin with a high success rate, as demonstrated in [this video](https://x.com/RemiCadene/status/1814680760592572934).
|
||||
|
||||
This tutorial is specifically made for the affordable [Koch v1.1](https://github.com/jess-moss/koch-v1-1) robot, but it contains additional information to be easily adapted to various types of robots like [Aloha bimanual robot](https://aloha-2.github.io) by changing some configurations. The Koch v1.1 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot.
|
||||
|
||||
During the data collection phase, you will control the follower arm by moving the leader arm. This process is known as "teleoperation." This technique is used to collect robot trajectories. Afterward, you'll train a neural network to imitate these trajectories and deploy the network to enable your robot to operate autonomously.
|
||||
|
||||
If you encounter any issues at any step of the tutorial, feel free to seek help on [Discord](https://discord.com/invite/s3KuuzsPFb) or don't hesitate to iterate with us on the tutorial by creating issues or pull requests. Thanks!
|
||||
|
||||
## 1. Order and Assemble your Koch v1.1
|
||||
|
||||
Follow the sourcing and assembling instructions provided on the [Koch v1.1 Github page](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below.
|
||||
|
||||
<div style="text-align:center;">
|
||||
<img src="../media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
|
||||
</div>
|
||||
|
||||
For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk).
|
||||
|
||||
## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1
|
||||
|
||||
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands (make sure gcc is installed).
|
||||
|
||||
Using `pip`:
|
||||
```bash
|
||||
pip install -e ".[dynamixel]"
|
||||
```
|
||||
|
||||
Using `poetry`:
|
||||
```bash
|
||||
poetry sync --extras "dynamixel"
|
||||
```
|
||||
|
||||
Using `uv`:
|
||||
```bash
|
||||
uv sync --extra "dynamixel"
|
||||
```
|
||||
|
||||
You are now ready to plug the 5V power supply to the motor bus of the leader arm (the smaller one) since all its motors only require 5V.
|
||||
|
||||
Then plug the 12V power supply to the motor bus of the follower arm. It has two motors that need 12V, and the rest will be powered with 5V through the voltage convertor.
|
||||
|
||||
Finally, connect both arms to your computer via USB. Note that the USB doesn't provide any power, and both arms need to be plugged in with their associated power supply to be detected by your computer.
|
||||
|
||||
Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file.
|
||||
|
||||
If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial):
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
It will automatically:
|
||||
1. Identify any missing calibrations and initiate the calibration procedure.
|
||||
2. Connect the robot and start teleoperation.
|
||||
|
||||
### a. Control your motors with DynamixelMotorsBus
|
||||
|
||||
You can use the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) to communicate with the motors connected as a chain to the corresponding USB bus. This class leverages the Python [Dynamixel SDK](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20) to facilitate reading from and writing to the motors.
|
||||
|
||||
**First Configuration of your motors**
|
||||
|
||||
You will need to unplug each motor in turn and run a command the identify the motor. The motor will save its own identification, so you only need to do this once. Start by unplugging all of the motors.
|
||||
|
||||
Do the Leader arm first, as all of its motors are of the same type. Plug in your first motor on your leader arm and run this script to set its ID to 1.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--id 1
|
||||
```
|
||||
|
||||
Then unplug your first motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--id 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6.
|
||||
|
||||
The process for the follower arm is almost the same, but the follower arm has two types of motors. For the first two motors, make sure you set the model to `xl430-w250`. _Important: configuring follower motors requires plugging and unplugging power. Make sure you use the 5V power for the XL330s and the 12V power for the XL430s!_
|
||||
|
||||
After all of your motors are configured properly, you're ready to plug them all together in a daisy-chain as shown in the original video.
|
||||
|
||||
**Instantiate the DynamixelMotorsBus**
|
||||
|
||||
To begin, create two instances of the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py), one for each arm, using their corresponding USB ports (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`).
|
||||
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running this command with your ports:
|
||||
```bash
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0032081
|
||||
sudo chmod 666 /dev/tty.usbmodem575E0031751
|
||||
```
|
||||
|
||||
*Listing and Configuring Motors*
|
||||
|
||||
Next, you'll need to list the motors for each arm, including their name, index, and model. Initially, each motor is assigned the factory default index `1`. Since each motor requires a unique index to function correctly when connected in a chain on a common bus, you'll need to assign different indices. It's recommended to use an ascending index order, starting from `1` (e.g., `1, 2, 3, 4, 5, 6`). These indices will be saved in the persistent memory of each motor during the first connection.
|
||||
|
||||
To assign indices to the motors, run this code in an interactive Python session. Replace the `port` values with the ones you identified earlier:
|
||||
```python
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
leader_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
"shoulder_lift": (2, "xl330-m077"),
|
||||
"elbow_flex": (3, "xl330-m077"),
|
||||
"wrist_flex": (4, "xl330-m077"),
|
||||
"wrist_roll": (5, "xl330-m077"),
|
||||
"gripper": (6, "xl330-m077"),
|
||||
},
|
||||
)
|
||||
|
||||
follower_config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
)
|
||||
|
||||
leader_arm = DynamixelMotorsBus(leader_config)
|
||||
follower_arm = DynamixelMotorsBus(follower_config)
|
||||
```
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update [`KochRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```python
|
||||
@RobotConfig.register_subclass("koch")
|
||||
@dataclass
|
||||
class KochRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Connect and Configure your Motors**
|
||||
|
||||
Before you can start using your motors, you'll need to configure them to ensure proper communication. When you first connect the motors, the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) automatically detects any mismatch between the current motor indices (factory set to `1`) and the specified indices (e.g., `1, 2, 3, 4, 5, 6`). This triggers a configuration procedure that requires you to unplug the power cord and motors, then reconnect each motor sequentially, starting from the one closest to the bus.
|
||||
|
||||
For a visual guide, refer to the [video tutorial of the configuration procedure](https://youtu.be/U78QQ9wCdpY).
|
||||
|
||||
To connect and configure the leader arm, run the following code in the same Python interactive session as earlier in the tutorial:
|
||||
```python
|
||||
leader_arm.connect()
|
||||
```
|
||||
|
||||
When you connect the leader arm for the first time, you might see an output similar to this:
|
||||
```
|
||||
Read failed due to communication error on port /dev/tty.usbmodem575E0032081 for group_key ID_shoulder_pan_shoulder_lift_elbow_flex_wrist_flex_wrist_roll_gripper: [TxRxResult] There is no status packet!
|
||||
|
||||
/!\ A configuration issue has been detected with your motors:
|
||||
If this is the first time you are using these motors, press enter to configure your motors... but before verify that all the cables are connected the proper way. If you find an issue, before making a modification, kill the python process, unplug the power cord to not damage the motors, rewire correctly, then plug the power again and relaunch the script.
|
||||
|
||||
Motor indices detected: {9600: [1]}
|
||||
|
||||
1. Unplug the power cord
|
||||
2. Plug/unplug minimal number of cables to only have the first 1 motor(s) (['shoulder_pan']) connected.
|
||||
3. Re-plug the power cord
|
||||
Press Enter to continue...
|
||||
|
||||
*Follow the procedure*
|
||||
|
||||
Setting expected motor indices: [1, 2, 3, 4, 5, 6]
|
||||
```
|
||||
|
||||
Once the leader arm is configured, repeat the process for the follower arm by running:
|
||||
```python
|
||||
follower_arm.connect()
|
||||
```
|
||||
|
||||
Congratulations! Both arms are now properly configured and connected. You won't need to go through the configuration procedure again in the future.
|
||||
|
||||
**Troubleshooting**:
|
||||
|
||||
If the configuration process fails, you may need to do the configuration process via the Dynamixel Wizard.
|
||||
|
||||
Known failure modes:
|
||||
- Calling `arm.connect()` raises `OSError: No motor found, but one new motor expected. Verify power cord is plugged in and retry` on Ubuntu 22.
|
||||
|
||||
Steps:
|
||||
1. Visit https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2/#connect-dynamixel.
|
||||
2. Follow the software installation instructions in section 3 of the web page.
|
||||
3. Launch the software.
|
||||
4. Configure the device scanning options in the menu under `Tools` > `Options` > `Scan`. Check only Protocol 2.0, select only the USB port identifier of interest, select all baudrates, set the ID range to `[0, 10]`. _While this step was not strictly necessary, it greatly speeds up scanning_.
|
||||
5. For each motor in turn:
|
||||
- Disconnect the power to the driver board.
|
||||
- Connect **only** the motor of interest to the driver board, making sure to disconnect it from any other motors.
|
||||
- Reconnect the power to the driver board.
|
||||
- From the software menu select `Device` > `Scan` and let the scan run. A device should appear.
|
||||
- If the device has an asterisk (*) near it, it means the firmware is indeed outdated. From the software menu, select `Tools` > `Firmware Update`. Follow the prompts.
|
||||
- The main panel should have table with various parameters of the device (refer to the web page, section 5). Select the row with `ID`, and then set the desired ID on the bottom right panel by selecting and clicking `Save`.
|
||||
- Just like you did with the ID, also set the `Baud Rate` to 1 Mbps.
|
||||
6. Check everything has been done right:
|
||||
- Rewire the arms in their final configuration and power both of them.
|
||||
- Scan for devices. All 12 motors should appear.
|
||||
- Select the motors one by one and move the arm. Check that the graphical indicator near the top right shows the movement.
|
||||
|
||||
** There is a common issue with the Dynamixel XL430-W250 motors where the motors become undiscoverable after upgrading their firmware from Mac and Windows Dynamixel Wizard2 applications. When this occurs, it is required to do a firmware recovery (Select `DYNAMIXEL Firmware Recovery` and follow the prompts). There are two known workarounds to conduct this firmware reset:
|
||||
1) Install the Dynamixel Wizard on a linux machine and complete the firmware recovery
|
||||
2) Use the Dynamixel U2D2 in order to perform the reset with Windows or Mac. This U2D2 can be purchased [here](https://www.robotis.us/u2d2/).
|
||||
For either solution, open DYNAMIXEL Wizard 2.0 and select the appropriate port. You will likely be unable to see the motor in the GUI at this time. Select `Firmware Recovery`, carefully choose the correct model, and wait for the process to complete. Finally, re-scan to confirm the firmware recovery was successful.
|
||||
|
||||
**Read and Write with DynamixelMotorsBus**
|
||||
|
||||
To get familiar with how `DynamixelMotorsBus` communicates with the motors, you can start by reading data from them. Copy past this code in the same interactive python session:
|
||||
```python
|
||||
leader_pos = leader_arm.read("Present_Position")
|
||||
follower_pos = follower_arm.read("Present_Position")
|
||||
print(leader_pos)
|
||||
print(follower_pos)
|
||||
```
|
||||
|
||||
Expected output might look like:
|
||||
```
|
||||
array([2054, 523, 3071, 1831, 3049, 2441], dtype=int32)
|
||||
array([2003, 1601, 56, 2152, 3101, 2283], dtype=int32)
|
||||
```
|
||||
|
||||
Try moving the arms to various positions and observe how the values change.
|
||||
|
||||
Now let's try to enable torque in the follower arm by copy pasting this code:
|
||||
```python
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
|
||||
follower_arm.write("Torque_Enable", TorqueMode.ENABLED.value)
|
||||
```
|
||||
|
||||
With torque enabled, the follower arm will be locked in its current position. Do not attempt to manually move the arm while torque is enabled, as this could damage the motors.
|
||||
|
||||
Now, to get more familiar with reading and writing, let's move the arm programmatically copy pasting the following example code:
|
||||
```python
|
||||
# Get the current position
|
||||
position = follower_arm.read("Present_Position")
|
||||
|
||||
# Update first motor (shoulder_pan) position by +10 steps
|
||||
position[0] += 10
|
||||
follower_arm.write("Goal_Position", position)
|
||||
|
||||
# Update all motors position by -30 steps
|
||||
position -= 30
|
||||
follower_arm.write("Goal_Position", position)
|
||||
|
||||
# Update gripper by +30 steps
|
||||
position[-1] += 30
|
||||
follower_arm.write("Goal_Position", position[-1], "gripper")
|
||||
```
|
||||
|
||||
When you're done playing, you can try to disable the torque, but make sure you hold your robot so that it doesn't fall:
|
||||
```python
|
||||
follower_arm.write("Torque_Enable", TorqueMode.DISABLED.value)
|
||||
```
|
||||
|
||||
Finally, disconnect the arms:
|
||||
```python
|
||||
leader_arm.disconnect()
|
||||
follower_arm.disconnect()
|
||||
```
|
||||
|
||||
Alternatively, you can unplug the power cord, which will automatically disable torque and disconnect the motors.
|
||||
|
||||
*/!\ Warning*: These motors tend to overheat, especially under torque or if left plugged in for too long. Unplug after use.
|
||||
|
||||
### b. Teleoperate your Koch v1.1 with ManipulatorRobot
|
||||
|
||||
**Instantiate the ManipulatorRobot**
|
||||
|
||||
Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_config` and `follower_config`.
|
||||
|
||||
For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_config}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_config, "right": right_leader_config},`. Same thing for the follower arms.
|
||||
|
||||
|
||||
Run the following code to instantiate your manipulator robot:
|
||||
```python
|
||||
from lerobot.common.robot_devices.robots.configs import KochRobotConfig
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
|
||||
robot_config = KochRobotConfig(
|
||||
leader_arms={"main": leader_config},
|
||||
follower_arms={"main": follower_config},
|
||||
cameras={}, # We don't use any camera for now
|
||||
)
|
||||
robot = ManipulatorRobot(robot_config)
|
||||
```
|
||||
|
||||
The `KochRobotConfig` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.
|
||||
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha.
|
||||
|
||||
**Calibrate and Connect the ManipulatorRobot**
|
||||
|
||||
Next, you'll need to calibrate your Koch robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Koch robot to work on another.
|
||||
|
||||
When you connect your robot for the first time, the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) will detect if the calibration file is missing and trigger the calibration procedure. During this process, you will be guided to move each arm to three different positions.
|
||||
|
||||
Here are the positions you'll move the follower arm to:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
And here are the corresponding positions for the leader arm:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||
|
||||
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
||||
|
||||
Importantly, once calibrated, all Koch robots will move to the same positions (e.g. zero and rotated position) when commanded.
|
||||
|
||||
Run the following code to calibrate and connect your robot:
|
||||
```python
|
||||
robot.connect()
|
||||
```
|
||||
|
||||
The output will look like this:
|
||||
```
|
||||
Connecting main follower arm
|
||||
Connecting main leader arm
|
||||
|
||||
Missing calibration file '.cache/calibration/koch/main_follower.json'
|
||||
Running calibration of koch main follower...
|
||||
Move arm to zero position
|
||||
[...]
|
||||
Move arm to rotated position
|
||||
[...]
|
||||
Move arm to rest position
|
||||
[...]
|
||||
Calibration is done! Saving calibration file '.cache/calibration/koch/main_follower.json'
|
||||
|
||||
Missing calibration file '.cache/calibration/koch/main_leader.json'
|
||||
Running calibration of koch main leader...
|
||||
Move arm to zero position
|
||||
[...]
|
||||
Move arm to rotated position
|
||||
[...]
|
||||
Move arm to rest position
|
||||
[...]
|
||||
Calibration is done! Saving calibration file '.cache/calibration/koch/main_leader.json'
|
||||
```
|
||||
|
||||
*Verifying Calibration*
|
||||
|
||||
Once calibration is complete, you can check the positions of the leader and follower arms to ensure they match. If the calibration was successful, the positions should be very similar.
|
||||
|
||||
Run this code to get the positions in degrees:
|
||||
```python
|
||||
leader_pos = robot.leader_arms["main"].read("Present_Position")
|
||||
follower_pos = robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
print(leader_pos)
|
||||
print(follower_pos)
|
||||
```
|
||||
|
||||
Example output:
|
||||
```
|
||||
array([-0.43945312, 133.94531, 179.82422, -18.984375, -1.9335938, 34.541016], dtype=float32)
|
||||
array([-0.58723712, 131.72314, 174.98743, -16.872612, 0.786213, 35.271973], dtype=float32)
|
||||
```
|
||||
|
||||
These values are in degrees, which makes them easier to interpret and debug. The zero position used during calibration should roughly correspond to 0 degrees for each motor, and the rotated position should roughly correspond to 90 degrees for each motor.
|
||||
|
||||
**Teleoperate your Koch v1.1**
|
||||
|
||||
You can easily teleoperate your robot by reading the positions from the leader arm and sending them as goal positions to the follower arm.
|
||||
|
||||
To teleoperate your robot for 30 seconds at a frequency of approximately 200Hz, run the following code:
|
||||
```python
|
||||
import tqdm
|
||||
seconds = 30
|
||||
frequency = 200
|
||||
for _ in tqdm.tqdm(range(seconds*frequency)):
|
||||
leader_pos = robot.leader_arms["main"].read("Present_Position")
|
||||
robot.follower_arms["main"].write("Goal_Position", leader_pos)
|
||||
```
|
||||
|
||||
*Using `teleop_step` for Teleoperation*
|
||||
|
||||
Alternatively, you can teleoperate the robot using the `teleop_step` method from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py).
|
||||
|
||||
Run this code to teleoperate:
|
||||
```python
|
||||
for _ in tqdm.tqdm(range(seconds*frequency)):
|
||||
robot.teleop_step()
|
||||
```
|
||||
|
||||
*Recording data during Teleoperation*
|
||||
|
||||
Teleoperation is particularly useful for recording data. You can use the `teleop_step(record_data=True)` to returns both the follower arm's position as `"observation.state"` and the leader arm's position as `"action"`. This function also converts the numpy arrays into PyTorch tensors. If you're working with a robot that has two leader and two follower arms (like the Aloha), the positions are concatenated.
|
||||
|
||||
Run the following code to see how slowly moving the leader arm affects the observation and action:
|
||||
```python
|
||||
leader_pos = robot.leader_arms["main"].read("Present_Position")
|
||||
follower_pos = robot.follower_arms["main"].read("Present_Position")
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
print(follower_pos)
|
||||
print(observation)
|
||||
print(leader_pos)
|
||||
print(action)
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
array([7.8223, 131.1328, 165.5859, -23.4668, -0.9668, 32.4316], dtype=float32)
|
||||
{'observation.state': tensor([7.8223, 131.1328, 165.5859, -23.4668, -0.9668, 32.4316])}
|
||||
array([3.4277, 134.1211, 179.8242, -18.5449, -1.5820, 34.7168], dtype=float32)
|
||||
{'action': tensor([3.4277, 134.1211, 179.8242, -18.5449, -1.5820, 34.7168])}
|
||||
```
|
||||
|
||||
*Asynchronous Frame Recording*
|
||||
|
||||
Additionally, `teleop_step` can asynchronously record frames from multiple cameras and include them in the observation dictionary as `"observation.images.CAMERA_NAME"`. This feature will be covered in more detail in the next section.
|
||||
|
||||
*Disconnecting the Robot*
|
||||
|
||||
When you're finished, make sure to disconnect your robot by running:
|
||||
```python
|
||||
robot.disconnect()
|
||||
```
|
||||
|
||||
Alternatively, you can unplug the power cord, which will also disable torque.
|
||||
|
||||
*/!\ Warning*: These motors tend to overheat, especially under torque or if left plugged in for too long. Unplug after use.
|
||||
|
||||
### c. Add your cameras with OpenCVCamera
|
||||
|
||||
**(Optional) Use your phone as camera on Linux**
|
||||
|
||||
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
|
||||
|
||||
1. *Install `v4l2loopback-dkms` and `v4l-utils`*. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
|
||||
```python
|
||||
sudo apt install v4l2loopback-dkms v4l-utils
|
||||
```
|
||||
2. *Install [DroidCam](https://droidcam.app) on your phone*. This app is available for both iOS and Android.
|
||||
3. *Install [OBS Studio](https://obsproject.com)*. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio
|
||||
```
|
||||
4. *Install the DroidCam OBS plugin*. This plugin integrates DroidCam with OBS Studio. Install it with:
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
|
||||
```
|
||||
5. *Start OBS Studio*. Launch with:
|
||||
```python
|
||||
flatpak run com.obsproject.Studio
|
||||
```
|
||||
6. *Add your phone as a source*. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
|
||||
7. *Adjust resolution settings*. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
|
||||
8. *Start virtual camera*. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
|
||||
9. *Verify the virtual camera setup*. Use `v4l2-ctl` to list the devices:
|
||||
```python
|
||||
v4l2-ctl --list-devices
|
||||
```
|
||||
You should see an entry like:
|
||||
```
|
||||
VirtualCam (platform:v4l2loopback-000):
|
||||
/dev/video1
|
||||
```
|
||||
10. *Check the camera resolution*. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
|
||||
```python
|
||||
v4l2-ctl -d /dev/video1 --get-fmt-video
|
||||
```
|
||||
You should see an entry like:
|
||||
```
|
||||
>>> Format Video Capture:
|
||||
>>> Width/Height : 640/480
|
||||
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
|
||||
```
|
||||
|
||||
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
|
||||
|
||||
If everything is set up correctly, you can proceed with the rest of the tutorial.
|
||||
|
||||
**(Optional) Use your iPhone as a camera on MacOS**
|
||||
|
||||
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
- Ensure your Mac is running macOS 13 or later, and your iPhone is on iOS 16 or later.
|
||||
- Sign in both devices with the same Apple ID.
|
||||
- Connect your devices with a USB cable or turn on Wi-Fi and Bluetooth for a wireless connection.
|
||||
|
||||
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
|
||||
|
||||
Your iPhone should be detected automatically when running the camera setup script in the next section.
|
||||
|
||||
**Instantiate an OpenCVCamera**
|
||||
|
||||
The [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py) class allows you to efficiently record frames from most cameras using the [`opencv2`](https://docs.opencv.org) library. For more details on compatibility, see [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
|
||||
To instantiate an [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py), you need a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera like a webcam of a laptop, the camera index is usually `0` but it might differ, and the camera index might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system.
|
||||
|
||||
To find the camera indices, run the following utility script, which will save a few frames from each detected camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/opencv.py \
|
||||
--images-dir outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
```
|
||||
Mac or Windows detected. Finding available camera indices through scanning all indices from 0 to 60
|
||||
[...]
|
||||
Camera found at index 0
|
||||
Camera found at index 1
|
||||
[...]
|
||||
Connecting cameras
|
||||
OpenCVCamera(0, fps=30.0, width=1920.0, height=1080.0, color_mode=rgb)
|
||||
OpenCVCamera(1, fps=24.0, width=1920.0, height=1080.0, color_mode=rgb)
|
||||
Saving images to outputs/images_from_opencv_cameras
|
||||
Frame: 0000 Latency (ms): 39.52
|
||||
[...]
|
||||
Frame: 0046 Latency (ms): 40.07
|
||||
Images have been saved to outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
Check the saved images in `outputs/images_from_opencv_cameras` to identify which camera index corresponds to which physical camera (e.g. `0` for `camera_00` or `1` for `camera_01`):
|
||||
```
|
||||
camera_00_frame_000000.png
|
||||
[...]
|
||||
camera_00_frame_000047.png
|
||||
camera_01_frame_000000.png
|
||||
[...]
|
||||
camera_01_frame_000047.png
|
||||
```
|
||||
|
||||
Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green.
|
||||
|
||||
Finally, run this code to instantiate and connect your camera:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
config = OpenCVCameraConfig(camera_index=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
|
||||
print(color_image.shape)
|
||||
print(color_image.dtype)
|
||||
```
|
||||
|
||||
Expected output for a laptop camera on MacBookPro:
|
||||
```
|
||||
(1080, 1920, 3)
|
||||
uint8
|
||||
```
|
||||
|
||||
Or like this if you followed our tutorial to set a virtual camera:
|
||||
```
|
||||
(480, 640, 3)
|
||||
uint8
|
||||
```
|
||||
|
||||
With certain camera, you can also specify additional parameters like frame rate, resolution, and color mode during instantiation. For instance:
|
||||
```python
|
||||
config = OpenCVCameraConfig(camera_index=0, fps=30, width=640, height=480)
|
||||
```
|
||||
|
||||
If the provided arguments are not compatible with the camera, an exception will be raised.
|
||||
|
||||
*Disconnecting the camera*
|
||||
|
||||
When you're done using the camera, disconnect it by running:
|
||||
```python
|
||||
camera.disconnect()
|
||||
```
|
||||
|
||||
**Instantiate your robot with cameras**
|
||||
|
||||
Additionally, you can set up your robot to work with your cameras.
|
||||
|
||||
Modify the following Python code with the appropriate camera names and configurations:
|
||||
```python
|
||||
robot = ManipulatorRobot(
|
||||
KochRobotConfig(
|
||||
leader_arms={"main": leader_arm},
|
||||
follower_arms={"main": follower_arm},
|
||||
calibration_dir=".cache/calibration/koch",
|
||||
cameras={
|
||||
"laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCameraConfig(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
)
|
||||
robot.connect()
|
||||
```
|
||||
|
||||
As a result, `teleop_step(record_data=True` will return a frame for each camera following the pytorch "channel first" convention but we keep images in `uint8` with pixels in range [0,255] to easily save them.
|
||||
|
||||
Modify this code with the names of your cameras and run it:
|
||||
```python
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
print(observation["observation.images.laptop"].shape)
|
||||
print(observation["observation.images.phone"].shape)
|
||||
print(observation["observation.images.laptop"].min().item())
|
||||
print(observation["observation.images.laptop"].max().item())
|
||||
```
|
||||
|
||||
The output should look like this:
|
||||
```
|
||||
torch.Size([3, 480, 640])
|
||||
torch.Size([3, 480, 640])
|
||||
0
|
||||
255
|
||||
```
|
||||
|
||||
### d. Use `control_robot.py` and our `teleoperate` function
|
||||
|
||||
Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the robot configurations via command line and control your robot with various modes as explained next.
|
||||
|
||||
Try running this code to teleoperate your robot (if you dont have a camera, keep reading):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz)
|
||||
```
|
||||
|
||||
It contains
|
||||
- `2024-08-10 11:15:03` which is the date and time of the call to the print function.
|
||||
- `ol_robot.py:209` which is the end of the file name and the line number where the print function is called (`lerobot/scripts/control_robot.py` line `209`).
|
||||
- `dt: 5.12 (195.1hz)` which is the "delta time" or the number of milliseconds spent between the previous call to `robot.teleop_step()` and the current one, associated with the frequency (5.12 ms equals 195.1 Hz) ; note that you can control the maximum frequency by adding fps as argument such as `--fps 30`.
|
||||
- `dtRlead: 4.93 (203.0hz)` which is the number of milliseconds it took to read the position of the leader arm using `leader_arm.read("Present_Position")`.
|
||||
- `dtWfoll: 0.22 (4446.9hz)` which is the number of milliseconds it took to set a new goal position for the follower arm using `follower_arm.write("Goal_position", leader_pos)` ; note that writing is done asynchronously so it takes less time than reading.
|
||||
|
||||
Importantly: If you don't have any camera, you can remove them dynamically with this [draccus](https://github.com/dlwh/draccus) syntax `--robot.cameras='{}'`:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
We advise to create a new yaml file when the command becomes too long.
|
||||
|
||||
## 3. Record your Dataset and Visualize it
|
||||
|
||||
Using what you've learned previously, you can now easily record a dataset of states and actions for one episode. You can use `busy_wait` to control the speed of teleoperation and record at a fixed `fps` (frame per seconds).
|
||||
|
||||
Try this code to record 30 seconds at 60 fps:
|
||||
```python
|
||||
import time
|
||||
from lerobot.scripts.control_robot import busy_wait
|
||||
|
||||
record_time_s = 30
|
||||
fps = 60
|
||||
|
||||
states = []
|
||||
actions = []
|
||||
for _ in range(record_time_s * fps):
|
||||
start_time = time.perf_counter()
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
states.append(observation["observation.state"])
|
||||
actions.append(action["action"])
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
# Note that observation and action are available in RAM, but
|
||||
# you could potentially store them on disk with pickle/hdf5 or
|
||||
# our optimized format `LeRobotDataset`. More on this next.
|
||||
```
|
||||
|
||||
Importantly, many utilities are still missing. For instance, if you have cameras, you will need to save the images on disk to not go out of RAM, and to do so in threads to not slow down communication with your robot. Also, you will need to store your data in a format optimized for training and web sharing like [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py). More on this in the next section.
|
||||
|
||||
### a. Use the `record` function
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to achieve efficient data recording. It encompasses many recording utilities:
|
||||
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
|
||||
2. Video streams from cameras are displayed in window so that you can verify them.
|
||||
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
|
||||
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
|
||||
5. Set the flow of data recording using command line arguments:
|
||||
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--control.reset_time_s=60` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--control.num_episodes=50` defines the number of episodes to record (50 by default).
|
||||
6. Control the flow during data recording using keyboard keys:
|
||||
- Press right arrow `->` at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
|
||||
- Press left arrow `<-` at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
|
||||
- Press escape `ESC` at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
|
||||
7. Similarly to `teleoperate`, you can also use the command line to override anything.
|
||||
|
||||
Before trying `record`, if you want to push your dataset to the hub, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
Also, store your Hugging Face repository name in a variable (e.g. `cadene` or `lerobot`). For instance, run this to use your Hugging Face user name as repository:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
If you don't want to push to hub, use `--control.push_to_hub=false`.
|
||||
|
||||
Now run this to record 2 episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
|
||||
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz)
|
||||
```
|
||||
It contains:
|
||||
- `2024-08-10 15:02:58` which is the date and time of the call to the print function,
|
||||
- `ol_robot.py:219` which is the end of the file name and the line number where the print function is called (`lerobot/scripts/control_robot.py` line `219`).
|
||||
- `dt:33.34 (30.0hz)` which is the "delta time" or the number of milliseconds spent between the previous call to `robot.teleop_step(record_data=True)` and the current one, associated with the frequency (33.34 ms equals 30.0 Hz) ; note that we use `--fps 30` so we expect 30.0 Hz ; when a step takes more time, the line appears in yellow.
|
||||
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
||||
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
|
||||
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
### b. Advice for recording dataset
|
||||
|
||||
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.
|
||||
|
||||
In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions.
|
||||
|
||||
Avoid adding too much variation too quickly, as it may hinder your results.
|
||||
|
||||
In the coming months, we plan to release a foundational model for robotics. We anticipate that fine-tuning this model will enhance generalization, reducing the need for strict consistency during data collection.
|
||||
|
||||
### c. Visualize all episodes
|
||||
|
||||
You can visualize your dataset by running:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
Note: You might need to add `--local-files-only 1` if your dataset was not uploaded to hugging face hub.
|
||||
|
||||
This will launch a local web server that looks like this:
|
||||
<div style="text-align:center;">
|
||||
<img src="../media/tutorial/visualize_dataset_html.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="100%">
|
||||
</div>
|
||||
|
||||
### d. Replay episode on your robot with the `replay` function
|
||||
|
||||
A useful feature of [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) is the `replay` function, which allows to replay on your robot any episode that you've recorded or episodes from any dataset out there. This function helps you test the repeatability of your robot's actions and assess transferability across robots of the same model.
|
||||
|
||||
To replay the first episode of the dataset you just recorded, run the following command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/koch_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
|
||||
## 4. Train a policy on your data
|
||||
|
||||
### a. Use the `train` script
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/koch_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_koch_test \
|
||||
--job_name=act_koch_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
|
||||
### b. (Optional) Upload policy checkpoints to the hub
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/act_koch_test \
|
||||
outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
You can also upload intermediate checkpoints with:
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/act_koch_test_${CKPT} \
|
||||
outputs/train/act_koch_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
## 5. Evaluate your policy
|
||||
|
||||
Now that you have a policy checkpoint, you can easily control your robot with it using methods from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) and the policy.
|
||||
|
||||
Try this code for running inference for 60 seconds at 30 fps:
|
||||
```python
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
inference_time_s = 60
|
||||
fps = 30
|
||||
device = "cuda" # TODO: On Mac, use "mps" or "cpu"
|
||||
|
||||
ckpt_path = "outputs/train/act_koch_test/checkpoints/last/pretrained_model"
|
||||
policy = ACTPolicy.from_pretrained(ckpt_path)
|
||||
policy.to(device)
|
||||
|
||||
for _ in range(inference_time_s * fps):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Read the follower state and access the frames from the cameras
|
||||
observation = robot.capture_observation()
|
||||
|
||||
# Convert to pytorch format: channel first and float32 in [0,1]
|
||||
# with batch dimension
|
||||
for name in observation:
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
# Order the robot to move
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
busy_wait(1 / fps - dt_s)
|
||||
```
|
||||
|
||||
### a. Use our `record` function
|
||||
|
||||
Ideally, when controlling your robot with your neural network, you would want to record evaluation episodes and to be able to visualize them later on, or even train on them like in Reinforcement Learning. This pretty much corresponds to recording a new dataset but with a neural network providing the actions instead of teleoperation.
|
||||
|
||||
To this end, you can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=koch \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/eval_act_koch_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_koch_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_koch_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_koch_test`).
|
||||
|
||||
### b. Visualize evaluation afterwards
|
||||
|
||||
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id ${HF_USER}/eval_act_koch_test
|
||||
```
|
||||
|
||||
## 6. Next step
|
||||
|
||||
Join our [Discord](https://discord.com/invite/s3KuuzsPFb) to collaborate on data collection and help us train a fully open-source foundational models for robotics!
|
||||
38
examples/lekiwi/evaluate.py
Normal file
38
examples/lekiwi/evaluate.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.common.utils.control_utils import predict_action
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
|
||||
NB_CYCLES_CLIENT_CONNECTION = 1000
|
||||
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
robot.connect()
|
||||
|
||||
policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle")
|
||||
policy.reset()
|
||||
|
||||
print("Running inference")
|
||||
i = 0
|
||||
while i < NB_CYCLES_CLIENT_CONNECTION:
|
||||
obs = robot.get_observation()
|
||||
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
obs[key] = value.numpy()
|
||||
|
||||
action_values = predict_action(
|
||||
obs, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
)
|
||||
action = {
|
||||
key: action_values[i].item() if isinstance(action_values[i], torch.Tensor) else action_values[i]
|
||||
for i, key in enumerate(robot.action_features)
|
||||
}
|
||||
robot.send_action(action)
|
||||
i += 1
|
||||
|
||||
robot.disconnect()
|
||||
67
examples/lekiwi/record.py
Normal file
67
examples/lekiwi/record.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import time
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.common.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
|
||||
NB_CYCLES_CLIENT_CONNECTION = 250
|
||||
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760431551")
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="user/lekiwi" + str(int(time.time())),
|
||||
fps=10,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
exit()
|
||||
|
||||
print("Starting LeKiwi teleoperation")
|
||||
i = 0
|
||||
while i < NB_CYCLES_CLIENT_CONNECTION:
|
||||
arm_action = leader_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
|
||||
keyboard_keys = keyboard.get_action()
|
||||
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
action_sent = robot.send_action(action)
|
||||
observation = robot.get_observation()
|
||||
|
||||
frame = {**action_sent, **observation}
|
||||
task = "Dummy Example Task Dataset"
|
||||
|
||||
dataset.add_frame(frame, task)
|
||||
i += 1
|
||||
|
||||
print("Disconnecting Teleop Devices and LeKiwi Client")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
|
||||
print("Uploading dataset to the hub")
|
||||
dataset.save_episode()
|
||||
dataset.push_to_hub()
|
||||
25
examples/lekiwi/replay.py
Normal file
25
examples/lekiwi/replay.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import time
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
dataset = LeRobotDataset("pepijn223/lekiwi1749025613", episodes=[0])
|
||||
|
||||
robot.connect()
|
||||
|
||||
print("Replaying episode…")
|
||||
for _, action_array in enumerate(dataset.hf_dataset["action"]):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
action = {name: float(action_array[i]) for i, name in enumerate(dataset.features["action"]["names"])}
|
||||
robot.send_action(action)
|
||||
|
||||
busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
print("Disconnecting LeKiwi Client")
|
||||
robot.disconnect()
|
||||
32
examples/lekiwi/teleoperate.py
Normal file
32
examples/lekiwi/teleoperate.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.common.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.common.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
|
||||
|
||||
teleop__arm_config = SO100LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_awesome_leader_arm",
|
||||
)
|
||||
|
||||
teleop_keyboard_config = KeyboardTeleopConfig(
|
||||
id="my_laptop_keyboard",
|
||||
)
|
||||
|
||||
robot = LeKiwiClient(robot_config)
|
||||
teleop_arm = SO100Leader(teleop__arm_config)
|
||||
telep_keyboard = KeyboardTeleop(teleop_keyboard_config)
|
||||
robot.connect()
|
||||
teleop_arm.connect()
|
||||
telep_keyboard.connect()
|
||||
|
||||
while True:
|
||||
observation = robot.get_observation()
|
||||
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
|
||||
keyboard_keys = telep_keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
robot.send_action(arm_action | base_action)
|
||||
@@ -1,94 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.common.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
|
||||
NB_CYCLES_CLIENT_CONNECTION = 250
|
||||
|
||||
|
||||
def main():
|
||||
logging.info("Configuring Teleop Devices")
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760433331")
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
logging.info("Configuring LeKiwi Client")
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
logging.info("Creating LeRobot Dataset")
|
||||
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="user/lekiwi" + str(int(time.time())),
|
||||
fps=10,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
logging.info("Connecting Teleop Devices")
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
logging.info("Connecting remote LeKiwi")
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
logging.error("Failed to connect to all devices")
|
||||
return
|
||||
|
||||
logging.info("Starting LeKiwi teleoperation")
|
||||
i = 0
|
||||
while i < NB_CYCLES_CLIENT_CONNECTION:
|
||||
arm_action = leader_arm.get_action()
|
||||
base_action = keyboard.get_action()
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
action_sent = robot.send_action(action)
|
||||
observation = robot.get_observation()
|
||||
|
||||
frame = {**action_sent, **observation}
|
||||
task = "Dummy Example Task Dataset"
|
||||
|
||||
logging.info("Saved a frame into the dataset")
|
||||
dataset.add_frame(frame, task)
|
||||
i += 1
|
||||
|
||||
logging.info("Disconnecting Teleop Devices and LeKiwi Client")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
|
||||
logging.info("Uploading dataset to the hub")
|
||||
dataset.save_episode()
|
||||
dataset.push_to_hub()
|
||||
|
||||
logging.info("Finished LeKiwi cleanly")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -168,12 +168,7 @@ available_datasets = sorted(
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/common/policies`
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
"vqbet",
|
||||
]
|
||||
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
|
||||
|
||||
# lists all available robots from `lerobot/common/robot_devices/robots`
|
||||
available_robots = [
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.common.robots import ( # noqa: F401
|
||||
lekiwi,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so100_follower_end_effector,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.common.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
|
||||
@@ -124,10 +124,9 @@ class OpenCVCamera(Camera):
|
||||
self.backend: int = get_cv2_backend()
|
||||
|
||||
if self.height and self.width:
|
||||
self.capture_width, self.capture_height = self.width, self.height
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
|
||||
self.capture_width, self.capture_height = self.height, self.width
|
||||
else:
|
||||
self.capture_width, self.capture_height = self.width, self.height
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.index_or_path})"
|
||||
@@ -206,12 +205,11 @@ class OpenCVCamera(Camera):
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
|
||||
if self.width is None or self.height is None:
|
||||
self.width, self.height = default_width, default_height
|
||||
self.capture_width, self.capture_height = default_width, default_height
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
|
||||
self.width, self.height = default_height, default_width
|
||||
self.capture_width, self.capture_height = default_width, default_height
|
||||
else:
|
||||
self.width, self.height = default_width, default_height
|
||||
self.capture_width, self.capture_height = default_width, default_height
|
||||
else:
|
||||
self._validate_width_and_height()
|
||||
|
||||
|
||||
@@ -138,10 +138,9 @@ class RealSenseCamera(Camera):
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
|
||||
if self.height and self.width:
|
||||
self.capture_width, self.capture_height = self.width, self.height
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
|
||||
self.capture_width, self.capture_height = self.height, self.width
|
||||
else:
|
||||
self.capture_width, self.capture_height = self.width, self.height
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.serial_number})"
|
||||
|
||||
@@ -38,6 +38,7 @@ from lerobot.common.datasets.utils import (
|
||||
DEFAULT_IMAGE_PATH,
|
||||
INFO_PATH,
|
||||
TASKS_PATH,
|
||||
_validate_feature_names,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
@@ -314,23 +315,9 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
# if robot is not None:
|
||||
# features = get_features_from_robot(robot, use_videos)
|
||||
# robot_type = robot.robot_type
|
||||
# if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
# logging.warning(
|
||||
# f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
# "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
||||
# )
|
||||
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
# check if none of the features contains a "/" in their names,
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
|
||||
@@ -14,13 +14,10 @@
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.robots import RobotConfig
|
||||
from lerobot.common.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@@ -158,125 +155,3 @@ class XarmEnv(EnvConfig):
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
|
||||
|
||||
# @dataclass
|
||||
# class EEActionSpaceConfig:
|
||||
# """Configuration parameters for end-effector action space."""
|
||||
|
||||
# x_step_size: float
|
||||
# y_step_size: float
|
||||
# z_step_size: float
|
||||
# bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
|
||||
# control_mode: str = "gamepad"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
control_mode: str = "gamepad"
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
|
||||
resize_size: Optional[Tuple[int, int]] = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Optional[Any] = None
|
||||
reset_time_s: float = 5.0
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
"""Configuration for the HILSerlRobotEnv environment."""
|
||||
|
||||
robot: Optional[RobotConfig] = None
|
||||
teleop: Optional[TeleoperatorConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
task: str = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
type: str = "hil"
|
||||
name: str = "PandaPickCube"
|
||||
task: str = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: Optional[str] = None
|
||||
robot_config: Optional[RobotConfig] = None
|
||||
teleop_config: Optional[TeleoperatorConfig] = None
|
||||
wrapper: Optional[EnvTransformConfig] = None
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
@@ -67,8 +65,5 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
env = env_cls(
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
|
||||
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19,
|
||||
# env = ObservationProcessorWrapper(env=env)
|
||||
|
||||
return env
|
||||
|
||||
@@ -47,10 +47,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
@@ -66,18 +62,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
|
||||
return_observations["observation.environment_state"] = env_state
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
observations["environment_state"]
|
||||
).float()
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations["observation.state"] = agent_pos
|
||||
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -1,589 +0,0 @@
|
||||
# ruff: noqa: N806, N815, N803
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
|
||||
def skew_symmetric(w):
|
||||
"""Creates the skew-symmetric matrix from a 3D vector."""
|
||||
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
|
||||
|
||||
|
||||
def rodrigues_rotation(w, theta):
|
||||
"""Computes the rotation matrix using Rodrigues' formula."""
|
||||
w_hat = skew_symmetric(w)
|
||||
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
|
||||
|
||||
def screw_axis_to_transform(S, theta):
|
||||
"""Converts a screw axis to a 4x4 transformation matrix."""
|
||||
S_w = S[:3]
|
||||
S_v = S[3:]
|
||||
if np.allclose(S_w, 0) and np.linalg.norm(S_v) == 1: # Pure translation
|
||||
T = np.eye(4)
|
||||
T[:3, 3] = S_v * theta
|
||||
elif np.linalg.norm(S_w) == 1: # Rotation and translation
|
||||
w_hat = skew_symmetric(S_w)
|
||||
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
|
||||
t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
|
||||
T = np.eye(4)
|
||||
T[:3, :3] = R
|
||||
T[:3, 3] = t
|
||||
else:
|
||||
raise ValueError("Invalid screw axis parameters")
|
||||
return T
|
||||
|
||||
|
||||
def pose_difference_se3(pose1, pose2):
|
||||
"""
|
||||
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
|
||||
SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, combining rotation (SO(3)) and translation.
|
||||
Each 4x4 matrix has the following structure, a 3x3 rotation matrix in the top-left and a 3x1 translation vector in the top-right:
|
||||
|
||||
[R11 R12 R13 tx]
|
||||
[R21 R22 R23 ty]
|
||||
[R31 R32 R33 tz]
|
||||
[ 0 0 0 1]
|
||||
|
||||
where Rij is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector.
|
||||
|
||||
pose1 - pose2
|
||||
|
||||
Args:
|
||||
pose1: A 4x4 numpy array representing the first pose.
|
||||
pose2: A 4x4 numpy array representing the second pose.
|
||||
|
||||
Returns:
|
||||
A tuple (translation_diff, rotation_diff) where:
|
||||
- translation_diff is a 3x1 numpy array representing the translational difference.
|
||||
- rotation_diff is a 3x1 numpy array representing the rotational difference in axis-angle representation.
|
||||
"""
|
||||
|
||||
# Extract rotation matrices from poses
|
||||
R1 = pose1[:3, :3]
|
||||
R2 = pose2[:3, :3]
|
||||
|
||||
# Calculate translational difference
|
||||
translation_diff = pose1[:3, 3] - pose2[:3, 3]
|
||||
|
||||
# Calculate rotational difference using scipy's Rotation library
|
||||
R_diff = Rotation.from_matrix(R1 @ R2.T)
|
||||
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
|
||||
|
||||
return np.concatenate([translation_diff, rotation_diff])
|
||||
|
||||
|
||||
def se3_error(target_pose, current_pose):
|
||||
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
|
||||
R_target = target_pose[:3, :3]
|
||||
R_current = current_pose[:3, :3]
|
||||
R_error = R_target @ R_current.T
|
||||
rot_error = Rotation.from_matrix(R_error).as_rotvec()
|
||||
return np.concatenate([pos_error, rot_error])
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
"""Robot kinematics class supporting multiple robot models."""
|
||||
|
||||
# Robot measurements dictionary
|
||||
ROBOT_MEASUREMENTS = {
|
||||
"koch": {
|
||||
"gripper": [0.239, -0.001, 0.024],
|
||||
"wrist": [0.209, 0, 0.024],
|
||||
"forearm": [0.108, 0, 0.02],
|
||||
"humerus": [0, 0, 0.036],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"so100": {
|
||||
"gripper": [0.320, 0, 0.050],
|
||||
"wrist": [0.278, 0, 0.050],
|
||||
"forearm": [0.143, 0, 0.044],
|
||||
"humerus": [0.031, 0, 0.072],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"moss": {
|
||||
"gripper": [0.246, 0.013, 0.111],
|
||||
"wrist": [0.245, 0.002, 0.064],
|
||||
"forearm": [0.122, 0, 0.064],
|
||||
"humerus": [0.001, 0.001, 0.063],
|
||||
"shoulder": [0, 0, 0],
|
||||
"base": [0, 0, 0.02],
|
||||
},
|
||||
"so101": {
|
||||
"gripper": [0.33, 0.0, 0.285],
|
||||
"wrist": [0.30, 0.0, 0.267],
|
||||
"forearm": [0.25, 0.0, 0.266],
|
||||
"humerus": [0.06, 0.0, 0.264],
|
||||
"shoulder": [0.0, 0.0, 0.238],
|
||||
"base": [0.0, 0.0, 0.12],
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, robot_type="so100"):
|
||||
"""Initialize kinematics for the specified robot type.
|
||||
|
||||
Args:
|
||||
robot_type: String specifying the robot model ("koch", "so100", or "moss")
|
||||
"""
|
||||
if robot_type not in self.ROBOT_MEASUREMENTS:
|
||||
raise ValueError(
|
||||
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
|
||||
)
|
||||
|
||||
self.robot_type = robot_type
|
||||
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
|
||||
|
||||
# Initialize all transformation matrices and screw axes
|
||||
self._setup_transforms()
|
||||
|
||||
def _create_translation_matrix(self, x=0, y=0, z=0):
|
||||
"""Create a 4x4 translation matrix."""
|
||||
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
|
||||
|
||||
def _setup_transforms(self):
|
||||
"""Setup all transformation matrices and screw axes for the robot."""
|
||||
# Set up rotation matrices (constant across robot types)
|
||||
|
||||
# Gripper orientation
|
||||
self.gripper_X0 = np.array(
|
||||
[
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Wrist orientation
|
||||
self.wrist_X0 = np.array(
|
||||
[
|
||||
[0, -1, 0, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Base orientation
|
||||
self.base_X0 = np.array(
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper
|
||||
# Screw axis of gripper frame wrt base frame
|
||||
self.S_BG = np.array(
|
||||
[
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
self.measurements["gripper"][2],
|
||||
-self.measurements["gripper"][1],
|
||||
]
|
||||
)
|
||||
|
||||
# Gripper origin to centroid transform
|
||||
self.X_GoGc = self._create_translation_matrix(x=0.07)
|
||||
|
||||
# Gripper origin to tip transform
|
||||
self.X_GoGt = self._create_translation_matrix(x=0.12)
|
||||
|
||||
# 0-position gripper frame pose wrt base
|
||||
self.X_BoGo = self._create_translation_matrix(
|
||||
x=self.measurements["gripper"][0],
|
||||
y=self.measurements["gripper"][1],
|
||||
z=self.measurements["gripper"][2],
|
||||
)
|
||||
|
||||
# Wrist
|
||||
# Screw axis of wrist frame wrt base frame
|
||||
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
|
||||
|
||||
# 0-position origin to centroid transform
|
||||
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
|
||||
|
||||
# 0-position wrist frame pose wrt base
|
||||
self.X_BR = self._create_translation_matrix(
|
||||
x=self.measurements["wrist"][0],
|
||||
y=self.measurements["wrist"][1],
|
||||
z=self.measurements["wrist"][2],
|
||||
)
|
||||
|
||||
# Forearm
|
||||
# Screw axis of forearm frame wrt base frame
|
||||
self.S_BF = np.array(
|
||||
[
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
-self.measurements["forearm"][2],
|
||||
0,
|
||||
self.measurements["forearm"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Forearm origin + centroid transform
|
||||
self.X_FoFc = self._create_translation_matrix(x=0.036) # spellchecker:disable-line
|
||||
|
||||
# 0-position forearm frame pose wrt base
|
||||
self.X_BF = self._create_translation_matrix(
|
||||
x=self.measurements["forearm"][0],
|
||||
y=self.measurements["forearm"][1],
|
||||
z=self.measurements["forearm"][2],
|
||||
)
|
||||
|
||||
# Humerus
|
||||
# Screw axis of humerus frame wrt base frame
|
||||
self.S_BH = np.array(
|
||||
[
|
||||
0,
|
||||
-1,
|
||||
0,
|
||||
self.measurements["humerus"][2],
|
||||
0,
|
||||
-self.measurements["humerus"][0],
|
||||
]
|
||||
)
|
||||
|
||||
# Humerus origin to centroid transform
|
||||
self.X_HoHc = self._create_translation_matrix(x=0.0475)
|
||||
|
||||
# 0-position humerus frame pose wrt base
|
||||
self.X_BH = self._create_translation_matrix(
|
||||
x=self.measurements["humerus"][0],
|
||||
y=self.measurements["humerus"][1],
|
||||
z=self.measurements["humerus"][2],
|
||||
)
|
||||
|
||||
# Shoulder
|
||||
# Screw axis of shoulder frame wrt Base frame
|
||||
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
|
||||
|
||||
# Shoulder origin to centroid transform
|
||||
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
|
||||
|
||||
# 0-position shoulder frame pose wrt base
|
||||
self.X_BS = self._create_translation_matrix(
|
||||
x=self.measurements["shoulder"][0],
|
||||
y=self.measurements["shoulder"][1],
|
||||
z=self.measurements["shoulder"][2],
|
||||
)
|
||||
|
||||
# Base
|
||||
# Base origin to centroid transform
|
||||
self.X_BoBc = self._create_translation_matrix(y=0.015)
|
||||
|
||||
# World to base transform
|
||||
self.X_WoBo = self._create_translation_matrix(
|
||||
x=self.measurements["base"][0],
|
||||
y=self.measurements["base"][1],
|
||||
z=self.measurements["base"][2],
|
||||
)
|
||||
|
||||
# Pre-compute gripper post-multiplication matrix
|
||||
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
|
||||
|
||||
def fk_base(self):
|
||||
"""Forward kinematics for the base frame."""
|
||||
return self.X_WoBo @ self.X_BoBc @ self.base_X0
|
||||
|
||||
def fk_shoulder(self, robot_pos_deg):
|
||||
"""Forward kinematics for the shoulder frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS
|
||||
|
||||
def fk_humerus(self, robot_pos_deg):
|
||||
"""Forward kinematics for the humerus frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
|
||||
theta_shoulder_pan = robot_pos_rad[0]
|
||||
# NOTE: Negate shoulder lift angle for all robot types
|
||||
theta_shoulder_lift = -robot_pos_rad[1]
|
||||
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, theta_shoulder_pan)
|
||||
@ screw_axis_to_transform(self.S_BH, theta_shoulder_lift)
|
||||
@ self.X_HoHc
|
||||
@ self.X_BH
|
||||
)
|
||||
|
||||
def fk_forearm(self, robot_pos_deg):
|
||||
"""Forward kinematics for the forearm frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
|
||||
theta_shoulder_pan = robot_pos_rad[0]
|
||||
# NOTE: Negate shoulder lift angle for all robot types
|
||||
theta_shoulder_lift = -robot_pos_rad[1]
|
||||
theta_elbow_flex = robot_pos_rad[2]
|
||||
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, theta_shoulder_pan)
|
||||
@ screw_axis_to_transform(self.S_BH, theta_shoulder_lift)
|
||||
@ screw_axis_to_transform(self.S_BF, theta_elbow_flex)
|
||||
@ self.X_FoFc # spellchecker:disable-line
|
||||
@ self.X_BF
|
||||
)
|
||||
|
||||
def fk_wrist(self, robot_pos_deg):
|
||||
"""Forward kinematics for the wrist frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
|
||||
theta_shoulder_pan = robot_pos_rad[0]
|
||||
# NOTE: Negate shoulder lift angle for all robot types
|
||||
theta_shoulder_lift = -robot_pos_rad[1]
|
||||
theta_elbow_flex = robot_pos_rad[2]
|
||||
theta_wrist_flex = robot_pos_rad[3]
|
||||
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, theta_shoulder_pan)
|
||||
@ screw_axis_to_transform(self.S_BH, theta_shoulder_lift)
|
||||
@ screw_axis_to_transform(self.S_BF, theta_elbow_flex)
|
||||
@ screw_axis_to_transform(self.S_BR, theta_wrist_flex)
|
||||
@ self.X_RoRc
|
||||
@ self.X_BR
|
||||
@ self.wrist_X0
|
||||
)
|
||||
|
||||
def fk_gripper(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
|
||||
theta_shoulder_pan = robot_pos_rad[0]
|
||||
# NOTE: Negate shoulder lift angle for all robot types
|
||||
theta_shoulder_lift = -robot_pos_rad[1]
|
||||
theta_elbow_flex = robot_pos_rad[2]
|
||||
theta_wrist_flex = robot_pos_rad[3]
|
||||
theta_wrist_roll = robot_pos_rad[4]
|
||||
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, theta_shoulder_pan)
|
||||
@ screw_axis_to_transform(self.S_BH, theta_shoulder_lift)
|
||||
@ screw_axis_to_transform(self.S_BF, theta_elbow_flex)
|
||||
@ screw_axis_to_transform(self.S_BR, theta_wrist_flex)
|
||||
@ screw_axis_to_transform(self.S_BG, theta_wrist_roll)
|
||||
@ self._fk_gripper_post
|
||||
)
|
||||
|
||||
def fk_gripper_tip(self, robot_pos_deg):
|
||||
"""Forward kinematics for the gripper tip frame."""
|
||||
robot_pos_rad = robot_pos_deg / 180 * np.pi
|
||||
|
||||
theta_shoulder_pan = robot_pos_rad[0]
|
||||
# Negate shoulder lift angle for all robot types
|
||||
theta_shoulder_lift = -robot_pos_rad[1]
|
||||
theta_elbow_flex = robot_pos_rad[2]
|
||||
theta_wrist_flex = robot_pos_rad[3]
|
||||
theta_wrist_roll = robot_pos_rad[4]
|
||||
|
||||
return (
|
||||
self.X_WoBo
|
||||
@ screw_axis_to_transform(self.S_BS, theta_shoulder_pan)
|
||||
@ screw_axis_to_transform(self.S_BH, theta_shoulder_lift)
|
||||
@ screw_axis_to_transform(self.S_BF, theta_elbow_flex)
|
||||
@ screw_axis_to_transform(self.S_BR, theta_wrist_flex)
|
||||
@ screw_axis_to_transform(self.S_BG, theta_wrist_roll)
|
||||
@ self.X_GoGt
|
||||
@ self.X_BoGo
|
||||
@ self.gripper_X0
|
||||
)
|
||||
|
||||
def compute_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(6, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
pose_difference_se3(
|
||||
fk_func(robot_pos_deg[:-1] + delta),
|
||||
fk_func(robot_pos_deg[:-1] - delta),
|
||||
)
|
||||
/ eps
|
||||
)
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
|
||||
"""Finite differences to compute the positional Jacobian.
|
||||
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
|
||||
in the jth joint's velocity.
|
||||
|
||||
Args:
|
||||
robot_pos_deg: Current joint positions in degrees
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
eps = 1e-8
|
||||
jac = np.zeros(shape=(3, 5))
|
||||
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
|
||||
for el_ix in range(len(robot_pos_deg[:-1])):
|
||||
delta *= 0
|
||||
delta[el_ix] = eps / 2
|
||||
Sdot = (
|
||||
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
|
||||
) / eps
|
||||
jac[:, el_ix] = Sdot
|
||||
return jac
|
||||
|
||||
def ik(self, current_joint_pos, desired_ee_pose, position_only=True, fk_func=None):
|
||||
"""Inverse kinematics using gradient descent.
|
||||
|
||||
Args:
|
||||
current_joint_state: Initial joint positions in degrees
|
||||
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
|
||||
position_only: If True, only match end-effector position, not orientation
|
||||
fk_func: Forward kinematics function to use (defaults to fk_gripper)
|
||||
|
||||
Returns:
|
||||
Joint positions in degrees that achieve the desired end-effector pose
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = self.fk_gripper
|
||||
|
||||
# Do gradient descent.
|
||||
current_joint_state = current_joint_pos.copy()
|
||||
max_iterations = 5
|
||||
learning_rate = 1
|
||||
for _ in range(max_iterations):
|
||||
current_ee_pose = fk_func(current_joint_state)
|
||||
if not position_only:
|
||||
error = se3_error(desired_ee_pose, current_ee_pose)
|
||||
jac = self.compute_jacobian(current_joint_state, fk_func)
|
||||
else:
|
||||
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
|
||||
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
|
||||
delta_angles = np.linalg.pinv(jac) @ error
|
||||
current_joint_state[:-1] += learning_rate * delta_angles
|
||||
|
||||
if np.linalg.norm(error) < 5e-3:
|
||||
return current_joint_state
|
||||
return current_joint_state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
def run_test(robot_type):
|
||||
"""Run test suite for a specific robot type."""
|
||||
print(f"\n--- Testing {robot_type.upper()} Robot ---")
|
||||
|
||||
# Initialize kinematics for this robot
|
||||
robot = RobotKinematics(robot_type)
|
||||
|
||||
# Test 1: Forward kinematics consistency
|
||||
print("Test 1: Forward kinematics consistency")
|
||||
test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees
|
||||
|
||||
# Calculate FK for different joints
|
||||
shoulder_pose = robot.fk_shoulder(test_angles)
|
||||
humerus_pose = robot.fk_humerus(test_angles)
|
||||
forearm_pose = robot.fk_forearm(test_angles)
|
||||
wrist_pose = robot.fk_wrist(test_angles)
|
||||
gripper_pose = robot.fk_gripper(test_angles)
|
||||
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
|
||||
|
||||
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
|
||||
distances = [
|
||||
np.linalg.norm(shoulder_pose[:3, 3]),
|
||||
np.linalg.norm(humerus_pose[:3, 3]),
|
||||
np.linalg.norm(forearm_pose[:3, 3]),
|
||||
np.linalg.norm(wrist_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_pose[:3, 3]),
|
||||
np.linalg.norm(gripper_tip_pose[:3, 3]),
|
||||
]
|
||||
|
||||
# Check if distances generally increase along the chain
|
||||
is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
|
||||
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
|
||||
print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")
|
||||
|
||||
# Test 2: Jacobian computation
|
||||
print("Test 2: Jacobian computation")
|
||||
jacobian = robot.compute_jacobian(test_angles)
|
||||
positional_jacobian = robot.compute_positional_jacobian(test_angles)
|
||||
|
||||
# Check shapes
|
||||
jacobian_shape_ok = jacobian.shape == (6, 5)
|
||||
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
|
||||
|
||||
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
|
||||
print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")
|
||||
|
||||
# Test 3: Inverse kinematics
|
||||
print("Test 3: Inverse kinematics (position only)")
|
||||
|
||||
# Generate target pose from known joint angles
|
||||
original_angles = np.array([10, 20, 30, -10, 5, 0])
|
||||
target_pose = robot.fk_gripper(original_angles)
|
||||
|
||||
# Start IK from a different position
|
||||
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
# Measure IK performance
|
||||
start_time = time.time()
|
||||
computed_angles = robot.ik(initial_guess.copy(), target_pose)
|
||||
ik_time = time.time() - start_time
|
||||
|
||||
# Compute resulting pose from IK solution
|
||||
result_pose = robot.fk_gripper(computed_angles)
|
||||
|
||||
# Calculate position error
|
||||
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
|
||||
passed = pos_error < 0.01 # Accept errors less than 1cm
|
||||
|
||||
print(f" IK computation time: {ik_time:.4f} seconds")
|
||||
print(f" Position error: {pos_error:.4f}")
|
||||
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
|
||||
|
||||
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
|
||||
|
||||
# Run tests for all robot types
|
||||
results = {}
|
||||
for robot_type in ["koch", "so100", "moss", "so101"]:
|
||||
results[robot_type] = run_test(robot_type)
|
||||
|
||||
# Print overall summary
|
||||
print("\n=== Test Summary ===")
|
||||
all_passed = all(results.values())
|
||||
for robot_type, passed in results.items():
|
||||
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
|
||||
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
|
||||
@@ -1,41 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("dynamixel")
|
||||
@dataclass
|
||||
class DynamixelMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("feetech")
|
||||
@dataclass
|
||||
class FeetechMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
@@ -39,7 +39,6 @@ DEFAULT_BAUDRATE = 1_000_000
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ class FeetechMotorsBus(MotorsBus):
|
||||
)
|
||||
|
||||
def _assert_same_firmware(self) -> None:
|
||||
firmware_versions = self._read_firmware_version(self.ids)
|
||||
firmware_versions = self._read_firmware_version(self.ids, raise_on_error=True)
|
||||
if len(set(firmware_versions.values())) != 1:
|
||||
raise RuntimeError(
|
||||
"Some Motors use different firmware versions:"
|
||||
@@ -251,7 +251,6 @@ class FeetechMotorsBus(MotorsBus):
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
offsets, mins, maxes = {}, {}, {}
|
||||
drive_modes = dict.fromkeys(self.motors, 0)
|
||||
for motor in self.motors:
|
||||
mins[motor] = self.read("Min_Position_Limit", motor, normalize=False)
|
||||
maxes[motor] = self.read("Max_Position_Limit", motor, normalize=False)
|
||||
@@ -263,7 +262,7 @@ class FeetechMotorsBus(MotorsBus):
|
||||
for motor, m in self.motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=drive_modes[motor],
|
||||
drive_mode=0,
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
@@ -359,13 +358,10 @@ class FeetechMotorsBus(MotorsBus):
|
||||
self.port_handler.setPacketTimeoutMillis((wait_length * tx_time_per_byte) + (3.0 * scs.MAX_ID) + 16.0)
|
||||
|
||||
rxpacket = []
|
||||
while True:
|
||||
while not self.port_handler.isPacketTimeout() and rx_length < wait_length:
|
||||
rxpacket += self.port_handler.readPort(wait_length - rx_length)
|
||||
rx_length = len(rxpacket)
|
||||
|
||||
if self.port_handler.isPacketTimeout(): # or rx_length >= wait_length
|
||||
break
|
||||
|
||||
self.port_handler.is_using = False
|
||||
|
||||
if rx_length == 0:
|
||||
@@ -434,13 +430,13 @@ class FeetechMotorsBus(MotorsBus):
|
||||
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
|
||||
)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
continue
|
||||
|
||||
firm_ver_minor, comm, error = self._read(
|
||||
*FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error
|
||||
)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
continue
|
||||
|
||||
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
|
||||
|
||||
@@ -451,7 +447,7 @@ class FeetechMotorsBus(MotorsBus):
|
||||
for id_ in motor_ids:
|
||||
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
continue
|
||||
|
||||
model_numbers[id_] = model_nb
|
||||
|
||||
|
||||
@@ -38,8 +38,6 @@ from lerobot.common.utils.utils import enter_pressed, move_cursor_up
|
||||
NameOrID: TypeAlias = str | int
|
||||
Value: TypeAlias = int | float
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -241,11 +239,11 @@ class MotorsBus(abc.ABC):
|
||||
)
|
||||
bus.connect()
|
||||
|
||||
position = bus.read("Present_Position", normalize=False)
|
||||
position = bus.read("Present_Position", "my_motor", normalize=False)
|
||||
|
||||
# Move from a few motor steps as an example
|
||||
few_steps = 30
|
||||
bus.write("Goal_Position", position + few_steps, normalize=False)
|
||||
bus.write("Goal_Position", "my_motor", position + few_steps, normalize=False)
|
||||
|
||||
# When done, properly disconnect the port using
|
||||
bus.disconnect()
|
||||
@@ -449,7 +447,7 @@ class MotorsBus(abc.ABC):
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
raise ConnectionError(
|
||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
|
||||
"\nTry running `python lerobot/find_port.py`\n"
|
||||
) from e
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -499,6 +497,7 @@ class MotorsBus(abc.ABC):
|
||||
tqdm.write(f"Motors found for {baudrate=}: {pformat(ids_models, indent=4)}")
|
||||
baudrate_ids[baudrate] = list(ids_models)
|
||||
|
||||
bus.port_handler.closePort()
|
||||
return baudrate_ids
|
||||
|
||||
def setup_motor(
|
||||
@@ -582,8 +581,8 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
Args:
|
||||
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
|
||||
model (str): _description_
|
||||
num_retry (int, optional): _description_. Defaults to 0.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -748,7 +747,9 @@ class MotorsBus(abc.ABC):
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
while True:
|
||||
|
||||
user_pressed_enter = False
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
|
||||
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
|
||||
@@ -760,9 +761,9 @@ class MotorsBus(abc.ABC):
|
||||
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
|
||||
|
||||
if enter_pressed():
|
||||
break
|
||||
user_pressed_enter = True
|
||||
|
||||
if display_values:
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 3)
|
||||
|
||||
@@ -786,23 +787,14 @@ class MotorsBus(abc.ABC):
|
||||
raise ValueError(f"Invalid calibration for motor '{motor}': min and max are equal.")
|
||||
|
||||
bounded_val = min(max_, max(min_, val))
|
||||
# TODO(Steven): normalization can go boom if max_ == min_, we should add a check probably in record_ranges_of_motions
|
||||
# (which probably indicates the user forgot to move a motor, most likely a gripper-like one)
|
||||
if self.motors[motor].norm_mode is MotorNormMode.RANGE_M100_100:
|
||||
norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
|
||||
normalized_values[id_] = -norm if drive_mode else norm
|
||||
elif self.motors[motor].norm_mode is MotorNormMode.RANGE_0_100:
|
||||
norm = ((bounded_val - min_) / (max_ - min_)) * 100
|
||||
normalized_values[id_] = 100 - norm if drive_mode else norm
|
||||
elif self.motors[motor].norm_mode is MotorNormMode.DEGREE:
|
||||
resolution = self.model_resolution_table[self.motors[motor].model]
|
||||
if drive_mode:
|
||||
val *= -1
|
||||
# middle_pos = homing_offset + (resolution - 1) // 2
|
||||
middle_pos = int((max_ + min_) / 2)
|
||||
normalized_values[id_] = ((val - middle_pos) / (resolution // 2)) * 180
|
||||
else:
|
||||
# TODO(alibers): velocity and degree modes
|
||||
# TODO(alibers): degree mode
|
||||
raise NotImplementedError
|
||||
|
||||
return normalized_values
|
||||
@@ -828,15 +820,6 @@ class MotorsBus(abc.ABC):
|
||||
val = 100 - val if drive_mode else val
|
||||
bounded_val = min(100.0, max(0.0, val))
|
||||
unnormalized_values[id_] = int((bounded_val / 100) * (max_ - min_) + min_)
|
||||
elif self.motors[motor].norm_mode is MotorNormMode.DEGREE:
|
||||
resolution = self.model_resolution_table[self.motors[motor].model]
|
||||
middle_pos = int((max_ + min_) / 2)
|
||||
unnormalized_values[id_] = int((val / 180) * resolution // 2) + middle_pos
|
||||
if drive_mode:
|
||||
unnormalized_values[id_] *= -1
|
||||
|
||||
# if unnormalized_values[id_] < 0:
|
||||
# breakpoint()
|
||||
else:
|
||||
# TODO(aliberts): degree mode
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configs import MotorsBusConfig
|
||||
from .motors_bus import MotorsBus
|
||||
|
||||
|
||||
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
|
||||
motors_buses = {}
|
||||
|
||||
for key, cfg in motors_bus_configs.items():
|
||||
if cfg.type == "dynamixel":
|
||||
from .dynamixel import DynamixelMotorsBus
|
||||
|
||||
motors_buses[key] = DynamixelMotorsBus(cfg)
|
||||
|
||||
elif cfg.type == "feetech":
|
||||
from lerobot.common.motors.feetech.feetech import FeetechMotorsBus
|
||||
|
||||
motors_buses[key] = FeetechMotorsBus(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return motors_buses
|
||||
|
||||
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
if motor_type == "dynamixel":
|
||||
from .configs import DynamixelMotorsBusConfig
|
||||
from .dynamixel import DynamixelMotorsBus
|
||||
|
||||
config = DynamixelMotorsBusConfig(**kwargs)
|
||||
return DynamixelMotorsBus(config)
|
||||
|
||||
elif motor_type == "feetech":
|
||||
from feetech import FeetechMotorsBus
|
||||
|
||||
from .configs import FeetechMotorsBusConfig
|
||||
|
||||
config = FeetechMotorsBusConfig(**kwargs)
|
||||
return FeetechMotorsBus(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
@@ -14,9 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
@@ -45,16 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""
|
||||
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||
For example, you can have one optimizer for the policy and another one for the value function
|
||||
in reinforcement learning settings.
|
||||
|
||||
Returns:
|
||||
The optimizer or a dictionary of optimizers.
|
||||
"""
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -104,76 +94,7 @@ class SGDConfig(OptimizerConfig):
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("multi_adam")
|
||||
@dataclass
|
||||
class MultiAdamConfig(OptimizerConfig):
|
||||
"""Configuration for multiple Adam optimizers with different parameter groups.
|
||||
|
||||
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
|
||||
|
||||
Args:
|
||||
lr: Default learning rate (used if not specified for a group)
|
||||
weight_decay: Default weight decay (used if not specified for a group)
|
||||
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
||||
grad_clip_norm: Gradient clipping norm
|
||||
"""
|
||||
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
"""
|
||||
optimizers = {}
|
||||
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
# Create optimizer with merged parameters (defaults + group-specific)
|
||||
optimizer_kwargs = {
|
||||
"lr": group_config.get("lr", self.lr),
|
||||
"betas": group_config.get("betas", (0.9, 0.999)),
|
||||
"eps": group_config.get("eps", 1e-5),
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def save_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> None:
|
||||
"""Save optimizer state to disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to save the optimizer state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||
_save_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
# Handle single optimizer
|
||||
_save_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
"""Save a single optimizer's state to disk."""
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
@@ -181,44 +102,11 @@ def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""Load optimizer state from disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to load the optimizer state from.
|
||||
|
||||
Returns:
|
||||
The updated optimizer(s) with loaded state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
loaded_optimizers = {}
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
if optimizer_dir.exists():
|
||||
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
loaded_optimizers[name] = opt
|
||||
return loaded_optimizers
|
||||
else:
|
||||
# Handle single optimizer
|
||||
return _load_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
"""Load a single optimizer's state from disk."""
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
|
||||
# Handle case where 'state' key might not exist (for newly created optimizers)
|
||||
if "state" in state:
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
else:
|
||||
loaded_state_dict = {"state": {}}
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
|
||||
@@ -15,5 +15,6 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
@@ -27,7 +27,7 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -60,14 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
elif name == "smolvla":
|
||||
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "reward_classifier":
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
return SmolVLAPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -85,8 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -151,7 +151,6 @@ class Normalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
@@ -253,168 +252,3 @@ class Unnormalize(nn.Module):
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||
but is factored out so it can be reused by both classes and stay in sync.
|
||||
"""
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape: tuple[int, ...] = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# reduce spatial dimensions, keep channel dimension only
|
||||
c, *_ = shape
|
||||
shape = (c, 1, 1)
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
std_data = stats[key]["std"]
|
||||
if isinstance(mean_data, torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_mean", mean)
|
||||
module.register_buffer(f"{prefix}_std", std)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_min", min_val)
|
||||
module.register_buffer(f"{prefix}_max", max_val)
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "reward_classifier"
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
learning_rate: float = 1e-4
|
||||
weight_decay: float = 0.01
|
||||
grad_clip_norm: float = 1.0
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
grad_clip_norm=self.grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate feature configurations."""
|
||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||
if not has_image:
|
||||
raise ValueError(
|
||||
"You must provide an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
@@ -1,301 +0,0 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_IMAGE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Optional[Tensor] = None,
|
||||
hidden_states: Optional[Tensor] = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class SpatialLearnedEmbeddings(nn.Module):
|
||||
def __init__(self, height, width, channel, num_features=8):
|
||||
"""
|
||||
PyTorch implementation of learned spatial embeddings
|
||||
|
||||
Args:
|
||||
height: Spatial height of input features
|
||||
width: Spatial width of input features
|
||||
channel: Number of input channels
|
||||
num_features: Number of output embedding dimensions
|
||||
"""
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.channel = channel
|
||||
self.num_features = num_features
|
||||
|
||||
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
|
||||
|
||||
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Forward pass for spatial embedding
|
||||
|
||||
Args:
|
||||
features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch
|
||||
Returns:
|
||||
Output tensor of shape [B, C*F] or [C*F] if no batch
|
||||
"""
|
||||
|
||||
features = features.last_hidden_state
|
||||
|
||||
original_shape = features.shape
|
||||
if features.dim() == 3:
|
||||
features = features.unsqueeze(0) # Add batch dim
|
||||
|
||||
features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1]
|
||||
kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F]
|
||||
|
||||
# Element-wise multiplication and spatial reduction
|
||||
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W
|
||||
|
||||
# Reshape to combine channel and feature dimensions
|
||||
output = output.view(output.size(0), -1) # [B, C*F]
|
||||
|
||||
# Remove batch dim
|
||||
if len(original_shape) == 3:
|
||||
output = output.squeeze(0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
config_class = RewardClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
|
||||
# Extract image keys from input_features
|
||||
self.image_keys = [
|
||||
key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE)
|
||||
]
|
||||
|
||||
if self.is_cnn:
|
||||
self.encoders = nn.ModuleDict()
|
||||
for image_key in self.image_keys:
|
||||
encoder = self._create_single_encoder()
|
||||
self.encoders[image_key] = encoder
|
||||
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _create_single_encoder(self):
|
||||
encoder = nn.Sequential(
|
||||
self.encoder,
|
||||
SpatialLearnedEmbeddings(
|
||||
height=4,
|
||||
width=4,
|
||||
channel=self.feature_dim,
|
||||
num_features=self.config.image_embedding_pooling_dim,
|
||||
),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
return encoder
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.config.latent_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoders[image_key](x)
|
||||
return outputs
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(x)
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
|
||||
"""Extract image tensors and label tensors from batch."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
labels = batch["next.reward"]
|
||||
|
||||
return images, labels
|
||||
|
||||
def predict(self, xs: list) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier for inference."""
|
||||
encoder_outputs = torch.hstack(
|
||||
[self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)]
|
||||
)
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
# Get predictions
|
||||
outputs = self.predict(images)
|
||||
|
||||
# Calculate loss
|
||||
if self.config.num_classes == 2:
|
||||
# Binary classification
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
# Multi-class classification
|
||||
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
|
||||
# Calculate accuracy for logging
|
||||
correct = (predictions == labels).sum().item()
|
||||
total = labels.size(0)
|
||||
accuracy = 100 * correct / total
|
||||
|
||||
# Return loss and metrics for logging
|
||||
output_dict = {
|
||||
"accuracy": accuracy,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def predict_reward(self, batch, threshold=0.5):
|
||||
"""Eval method. Returns predicted reward with the decision threshold as argument."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images from batch dict
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.predict(images).probabilities
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not select actions")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
pass
|
||||
@@ -1,243 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
def is_image_feature(key: str) -> bool:
|
||||
"""Check if a feature key represents an image feature.
|
||||
|
||||
Args:
|
||||
key: The feature key to check
|
||||
|
||||
Returns:
|
||||
True if the key represents an image feature, False otherwise
|
||||
"""
|
||||
return key.startswith("observation.image")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConcurrencyConfig:
|
||||
"""Configuration for the concurrency of the actor and learner.
|
||||
Possible values are:
|
||||
- "threads": Use threads for the actor and learner.
|
||||
- "processes": Use processes for the actor and learner.
|
||||
"""
|
||||
|
||||
actor: str = "threads"
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
learner_port: int = 50051
|
||||
policy_parameters_push_frequency: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class CriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
log_std_min: float = 1e-5
|
||||
log_std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Statistics for normalizing different types of inputs
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
# Device to run the model on (e.g., "cuda", "cpu")
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
# Hidden dimension size for the image encoder
|
||||
image_encoder_hidden_dim: int = 32
|
||||
# Whether to use a shared encoder for actor and critic
|
||||
shared_encoder: bool = True
|
||||
# Number of discrete actions, eg for gripper actions
|
||||
num_discrete_actions: int | None = None
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Seed for the online environment
|
||||
online_env_seed: int = 10000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
has_image = any(is_image_feature(key) for key in self.input_features)
|
||||
has_state = "observation.state" in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features if is_image_feature(key)]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return None # SAC typically predicts one action at a time
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
154
lerobot/common/policies/smolvla/configuration_smolvla.py
Normal file
154
lerobot/common/policies/smolvla/configuration_smolvla.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla")
|
||||
@dataclass
|
||||
class SmolVLAConfig(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (512, 512)
|
||||
|
||||
# Add empty images. Used by smolvla_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = True
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
|
||||
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
|
||||
|
||||
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
|
||||
|
||||
attention_mode: str = "cross_attn"
|
||||
|
||||
prefix_length: int = -1
|
||||
|
||||
pad_language_to: str = "longest" # "max_length"
|
||||
|
||||
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
|
||||
num_vlm_layers: int = 16 # Number of layers used in the VLM (first num_vlm_layers layers)
|
||||
self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
|
||||
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
|
||||
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
801
lerobot/common/policies/smolvla/modeling_smolvla.py
Normal file
801
lerobot/common/policies/smolvla/modeling_smolvla.py
Normal file
@@ -0,0 +1,801 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SmolVLA:
|
||||
|
||||
[Paper](https://huggingface.co/papers/2506.01844)
|
||||
|
||||
Designed by Hugging Face.
|
||||
|
||||
Install smolvla extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[smolvla]"
|
||||
```
|
||||
|
||||
Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
|
||||
and an action expert.
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
Example of using the smolvla pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.common.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
return att_2d_masks
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with smolvla which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class SmolVLAPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = SmolVLAConfig
|
||||
name = "smolvla"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
for k in batch:
|
||||
if k in self._queues:
|
||||
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
# Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
if f"{key}_padding_mask" in batch:
|
||||
mask = batch[f"{key}_padding_mask"].bool()
|
||||
else:
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
# Create image features not present in the batch
|
||||
# as fully 0 padded images.
|
||||
for num_empty_cameras in range(len(missing_img_keys)):
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
img = torch.ones_like(img) * -1
|
||||
mask = torch.zeros_like(mask)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
tasks = batch["task"]
|
||||
if len(tasks) == 1:
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
|
||||
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding=self.config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
|
||||
state = pad_vector(state, self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
Efficiently pads a tensor along sequence dimension to match max_len.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
|
||||
max_len (int): Fixed sequence length.
|
||||
pad_value (int/float): Value for padding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
|
||||
"""
|
||||
b, d = tensor.shape[:2]
|
||||
|
||||
# Create a padded tensor of max_len and copy the existing values
|
||||
padded_tensor = torch.full(
|
||||
(b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
padded_tensor[:, :d] = tensor # Efficient in-place copy
|
||||
|
||||
return padded_tensor
|
||||
|
||||
|
||||
class VLAFlowMatching(nn.Module):
|
||||
"""
|
||||
SmolVLA
|
||||
|
||||
[Paper]()
|
||||
|
||||
Designed by Hugging Face.
|
||||
┌──────────────────────────────┐
|
||||
│ actions │
|
||||
│ ▲ │
|
||||
│ ┌─────────┐ ┌─|────┐ │
|
||||
│ | │────► │ │ │
|
||||
│ | │ kv │ │ │
|
||||
│ | │────► │Action│ │
|
||||
│ | VLM │cache │Expert│ |
|
||||
│ │ │────► | │ │
|
||||
│ │ │ │ │ │
|
||||
│ └▲──▲───▲─┘ └───▲──┘ |
|
||||
│ │ | | │ |
|
||||
│ | | | noise │
|
||||
│ │ │ state │
|
||||
│ │ language tokens │
|
||||
│ image(s) │
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.vlm_with_expert = SmolVLMWithExpertModel(
|
||||
model_id=self.config.vlm_model_name,
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
load_vlm_weights=self.config.load_vlm_weights,
|
||||
attention_mode=self.config.attention_mode,
|
||||
num_expert_layers=self.config.num_expert_layers,
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
|
||||
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
self.action_time_mlp_out = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
|
||||
self.set_requires_grad()
|
||||
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
|
||||
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
|
||||
self.global_image_start_token = torch.tensor(
|
||||
[self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||
)
|
||||
|
||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
self.prefix_length = self.config.prefix_length
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for SmolVLM transformer processing.
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
for _img_idx, (
|
||||
img,
|
||||
img_mask,
|
||||
) in enumerate(zip(images, img_masks, strict=False)):
|
||||
if self.add_image_special_tokens:
|
||||
image_start_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_start_mask = torch.ones_like(
|
||||
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
|
||||
)
|
||||
att_masks += [0] * (image_start_mask.shape[-1])
|
||||
embs.append(image_start_token)
|
||||
pad_masks.append(image_start_mask)
|
||||
|
||||
img_emb = self.vlm_with_expert.embed_image(img)
|
||||
img_emb = img_emb
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask)
|
||||
|
||||
att_masks += [0] * (num_img_embs)
|
||||
if self.add_image_special_tokens:
|
||||
image_end_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_end_mask = torch.ones_like(
|
||||
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
|
||||
)
|
||||
embs.append(image_end_token)
|
||||
pad_masks.append(image_end_mask)
|
||||
att_masks += [0] * (image_end_mask.shape[1])
|
||||
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||
embs.append(state_emb)
|
||||
bsize = state_emb.shape[0]
|
||||
device = state_emb.device
|
||||
|
||||
states_seq_len = state_emb.shape[1]
|
||||
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1] * (states_seq_len)
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
att_masks = att_masks[None, :]
|
||||
|
||||
seq_len = pad_masks.shape[1]
|
||||
if seq_len < self.prefix_length:
|
||||
embs = pad_tensor(embs, self.prefix_length, pad_value=0)
|
||||
pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
|
||||
att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
|
||||
|
||||
att_masks = att_masks.expand(bsize, -1)
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
device = action_emb.device
|
||||
bsize = action_emb.shape[0]
|
||||
dtype = action_emb.dtype
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.vlm_with_expert.expert_hidden_size,
|
||||
self.config.min_period,
|
||||
self.config.max_period,
|
||||
device=device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] * self.config.chunk_size
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
(_, suffix_out), _ = self.vlm_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
# Original openpi code, upcast attention output
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
# Compute image and language key value cache
|
||||
_, past_key_values = self.vlm_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
outputs_embeds, _ = self.vlm_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
550
lerobot/common/policies/smolvla/smolvlm_with_expert.py
Normal file
550
lerobot/common/policies/smolvla/smolvlm_with_expert.py
Normal file
@@ -0,0 +1,550 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
return hidden_dim
|
||||
|
||||
|
||||
class SmolVLMWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
freeze_vision_encoder: bool = False,
|
||||
attention_mode: str = "self_attn",
|
||||
num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1,
|
||||
self_attn_every_n_layers: int = -1,
|
||||
expert_width_multiplier: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
config = self.vlm.config
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
self.vlm = SmolVLMForConditionalGeneration(config=config)
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
if num_vlm_layers > 0:
|
||||
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
|
||||
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
|
||||
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
|
||||
self.config = config
|
||||
# Smaller lm expert
|
||||
lm_expert_config = copy.deepcopy(config.text_config)
|
||||
hidden_size = lm_expert_config.hidden_size
|
||||
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||
if num_expert_layers > 0:
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
)
|
||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||
|
||||
self.num_expert_layers = len(self.lm_expert.layers)
|
||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||
if "cross" in attention_mode:
|
||||
# Reshape qkv projections to have the same input dimension as the vlm
|
||||
for layer_idx in range(len(self.lm_expert.layers)):
|
||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||
continue
|
||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
# Remove unused embed_tokens
|
||||
self.lm_expert.embed_tokens = None
|
||||
|
||||
self.num_attention_heads = self.config.text_config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.text_config.num_key_value_heads
|
||||
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_mode = attention_mode
|
||||
self.expert_hidden_size = lm_expert_config.hidden_size
|
||||
self.set_requires_grad()
|
||||
|
||||
def get_vlm_model(self):
|
||||
return self.vlm.model
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
for params in self.get_vlm_model().vision_model.parameters():
|
||||
params.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
for params in self.vlm.parameters():
|
||||
params.requires_grad = False
|
||||
else:
|
||||
# To avoid unused params issue with distributed training
|
||||
last_layers = [self.num_vlm_layers - 1]
|
||||
if (
|
||||
self.num_vlm_layers != self.num_expert_layers
|
||||
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||
):
|
||||
last_layers.append(self.num_vlm_layers - 2)
|
||||
frozen_layers = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
for layer in last_layers:
|
||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||
|
||||
for name, params in self.vlm.named_parameters():
|
||||
if any(k in name for k in frozen_layers):
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
for name, params in self.lm_expert.named_parameters():
|
||||
if "lm_head" in name:
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
patch_attention_mask = None
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = (
|
||||
self.get_vlm_model()
|
||||
.vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
.last_hidden_state
|
||||
)
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
if hidden_states is None or layer is None:
|
||||
continue
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
seq_len = query_states.shape[1]
|
||||
if seq_len < position_ids.shape[1]:
|
||||
_position_ids = position_ids[:, :seq_len]
|
||||
_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
else:
|
||||
_position_ids = position_ids
|
||||
_attention_mask = attention_mask
|
||||
|
||||
attention_mask_ = _attention_mask
|
||||
position_ids_ = _position_ids
|
||||
|
||||
query_states = apply_rope(query_states, position_ids_)
|
||||
key_states = apply_rope(key_states, position_ids_)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_output = attention_interface(
|
||||
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
return [att_output], past_key_values
|
||||
|
||||
def forward_cross_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_outputs = []
|
||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
|
||||
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||
)
|
||||
|
||||
if len(inputs_embeds) == 2 and not past_key_values:
|
||||
# Prefix attention
|
||||
seq_len = inputs_embeds[0].shape[1]
|
||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
|
||||
layer = model_layers[0][layer_idx]
|
||||
|
||||
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
query_states = apply_rope(query_state, position_id)
|
||||
key_states = apply_rope(key_state, position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
expert_position_id = position_ids
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = past_key_values[layer_idx]["key_states"]
|
||||
value_states = past_key_values[layer_idx]["value_states"]
|
||||
|
||||
# Expert
|
||||
expert_layer = model_layers[1][layer_idx]
|
||||
if expert_layer is not None:
|
||||
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
|
||||
|
||||
expert_input_shape = expert_hidden_states.shape[:-1]
|
||||
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
|
||||
|
||||
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
||||
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
||||
|
||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
|
||||
*key_states.shape[:2], -1
|
||||
)
|
||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
|
||||
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
) # k_proj should have same dim as kv
|
||||
|
||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
|
||||
*value_states.shape[:2], -1
|
||||
)
|
||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
|
||||
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
)
|
||||
|
||||
expert_position_id = (
|
||||
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
|
||||
) # start from 0
|
||||
expert_attention_mask = attention_mask[
|
||||
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
|
||||
] # take into account kv
|
||||
|
||||
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
expert_attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
expert_query_states,
|
||||
expert_key_states,
|
||||
expert_value_states,
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
att_outputs.append(None)
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
return att_outputs, past_key_values
|
||||
|
||||
def get_model_layers(self, models: list) -> list:
|
||||
vlm_layers = []
|
||||
expert_layers = []
|
||||
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
||||
for i in range(self.num_vlm_layers):
|
||||
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
|
||||
expert_layer = None
|
||||
else:
|
||||
expert_layer_index = i // multiple_of if multiple_of > 0 else i
|
||||
expert_layer = models[1].layers[expert_layer_index]
|
||||
vlm_layers.append(models[0].layers[i])
|
||||
expert_layers.append(expert_layer)
|
||||
return [vlm_layers, expert_layers]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.get_vlm_model().text_model, self.lm_expert]
|
||||
model_layers = self.get_model_layers(models)
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.num_vlm_layers
|
||||
head_dim = self.vlm.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
if (
|
||||
fill_kv_cache
|
||||
or "cross" not in self.attention_mode
|
||||
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
|
||||
):
|
||||
att_outputs, past_key_values = self.forward_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
att_outputs, past_key_values = self.forward_cross_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
att_output = (
|
||||
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
|
||||
) # in case of self_attn
|
||||
if hidden_states is not None:
|
||||
if layer is None:
|
||||
outputs_embeds.append(hidden_states)
|
||||
continue
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
att_out = att_output[:, start:end]
|
||||
out_emb = layer.self_attn.o_proj(att_out)
|
||||
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end if len(att_outputs) == 1 else 0
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.num_attention_heads
|
||||
num_key_value_heads = self.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
|
||||
att_weights = att_weights.to(dtype=torch.float32)
|
||||
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
@@ -36,7 +36,7 @@ class LeKiwiConfig(RobotConfig):
|
||||
default_factory=lambda: {
|
||||
"front": OpenCVCameraConfig(index_or_path="/dev/video0", fps=30, width=640, height=480),
|
||||
"wrist": OpenCVCameraConfig(
|
||||
index_or_path="/dev/video2", fps=30, width=640, height=480, rotation=Cv2Rotation.ROTATE_90
|
||||
index_or_path="/dev/video2", fps=30, width=640, height=480, rotation=Cv2Rotation.ROTATE_180
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -88,33 +88,12 @@ The calibration process is very important because it allows a neural network tra
|
||||
### Calibrate follower arm (on mobile base)
|
||||
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
|
||||
<hfoptions id="calibrate_follower">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
--robot.type=lekiwi \
|
||||
--robot.port=/dev/ttyACM0 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_kiwi # <- Give the robot a unique name
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
|
||||
config = LeKiwiClientConfig(
|
||||
remote_ip="192.168.0.23",
|
||||
id="my_awesome_kiwi",
|
||||
)
|
||||
|
||||
lekiwi = LeKiwiClient(config)
|
||||
lekiwi.connect(calibrate=False)
|
||||
lekiwi.calibrate()
|
||||
lekiwi.disconnect()
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video).
|
||||
|
||||
@@ -161,60 +140,11 @@ To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and
|
||||
python -m lerobot.common.robots.lekiwi.lekiwi_host
|
||||
```
|
||||
|
||||
Then on your laptop, also run `conda activate lerobot` and this command or API example:
|
||||
|
||||
<hfoptions id="teleoperate_koch_camera">
|
||||
<hfoption id="Command">
|
||||
Then on your laptop, also run `conda activate lerobot` and run the API example, make sure you set the correct `remote_ip` and `port`.
|
||||
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
--robot.type=lekiwi \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{}" \
|
||||
--robot.id=my_lekiwi \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_blue_leader_arm
|
||||
python examples/lekiwi/teleoperate.py
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.common.teleoperators.keyboard.teleop_keyboard import KeyboardTeleopConfig, KeyboardTeleop
|
||||
from lerobot.common.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
|
||||
from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
|
||||
robot_config = LeKiwiClientConfig(
|
||||
remote_ip="172.18.133.90",
|
||||
id="my_red_lekiwi"
|
||||
)
|
||||
|
||||
teleop__arm_config = SO100LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
)
|
||||
|
||||
teleop_keyboard_config = KeyboardTeleopConfig(
|
||||
id="my_laptop_keyboard",
|
||||
)
|
||||
|
||||
robot = LeKiwiClient(robot_config)
|
||||
teleop_arm = SO100Leader(teleop__arm_config)
|
||||
telep_keyboard = KeyboardTeleop(teleop_keyboard_config)
|
||||
robot.connect()
|
||||
teleop_arm.connect()
|
||||
telep_keyboard.connect()
|
||||
|
||||
while True:
|
||||
observation = robot.get_observation()
|
||||
action_arm = teleop_arm.get_action()
|
||||
action_base = telep_keyboard.get_action()
|
||||
robot.send_action(action_arm | action_base)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port`
|
||||
|
||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||
|
||||
@@ -242,7 +172,69 @@ You should see on your laptop something like this: ```[INFO] Connected to remote
|
||||
### Wired version
|
||||
If you have the **wired** LeKiwi version, please run all commands on your laptop.
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial (you can skip the teleoperation part): [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset.
|
||||
|
||||
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
|
||||
|
||||
Add your token to the CLI by running this command:
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`.
|
||||
```bash
|
||||
python examples/lekiwi/record.py
|
||||
```
|
||||
|
||||
#### Dataset upload
|
||||
Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/so101_test
|
||||
```
|
||||
Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
|
||||
|
||||
#### Tips for gathering data
|
||||
|
||||
Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images.
|
||||
|
||||
In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions.
|
||||
|
||||
Avoid adding too much variation too quickly, as it may hinder your results.
|
||||
|
||||
If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset.
|
||||
|
||||
#### Troubleshooting:
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
|
||||
## Replay an episode
|
||||
|
||||
To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index.
|
||||
|
||||
|
||||
```bash
|
||||
python examples/lekiwi/replay.py
|
||||
```
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
To evaluate your policy run the `evaluate.py` API example, make sure to change `remote_ip`, `port`, model..
|
||||
|
||||
```bash
|
||||
python examples/lekiwi/evaluate.py
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO(aliberts, Steven, Pepijn): use gRPC calls instead of zmq?
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@@ -321,23 +323,11 @@ class LeKiwiClient(Robot):
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
common_keys = [
|
||||
key
|
||||
for key in action
|
||||
if key in (motor.replace("arm_", "") for motor, _ in self.action_features.items())
|
||||
]
|
||||
|
||||
arm_actions = {"arm_" + arm_motor: action[arm_motor] for arm_motor in common_keys}
|
||||
|
||||
keyboard_keys = np.array(list(set(action.keys()) - set(common_keys)))
|
||||
base_actions = self._from_keyboard_to_base_action(keyboard_keys)
|
||||
goal_pos = {**arm_actions, **base_actions}
|
||||
|
||||
self.zmq_cmd_socket.send_string(json.dumps(goal_pos)) # action is in motor space
|
||||
self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space
|
||||
|
||||
# TODO(Steven): Remove the np conversion when it is possible to record a non-numpy array value
|
||||
actions = np.array([goal_pos.get(k, 0.0) for k in self._state_order], dtype=np.float32)
|
||||
return {"action.state": actions}
|
||||
actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32)
|
||||
return {"action": actions}
|
||||
|
||||
def disconnect(self):
|
||||
"""Cleans ZMQ comms"""
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
from .config_so100_follower_end_effector import SO100FollowerEndEffectorConfig
|
||||
from .so100_follower_end_effector import SO100FollowerEndEffector
|
||||
@@ -1,61 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List
|
||||
|
||||
from lerobot.common.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100_follower_end_effector")
|
||||
@dataclass
|
||||
class SO100FollowerEndEffectorConfig(RobotConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
|
||||
# Port to connect to the arm
|
||||
port: str
|
||||
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: Dict[str, List[float]] = field(
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
|
||||
max_gripper_pos: float = 50
|
||||
|
||||
end_effector_step_sizes: Dict[str, float] = field(
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
|
||||
urdf_path: str = "/Users/michel_aractingi/code/SO-ARM100/Simulation/SO101/so101_new_calib.urdf"
|
||||
@@ -1,203 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.cameras import make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceNotConnectedError
|
||||
from lerobot.common.model.kinematics import RobotKinematics
|
||||
from lerobot.common.motors import Motor, MotorNormMode
|
||||
from lerobot.common.motors.feetech import FeetechMotorsBus
|
||||
|
||||
from ..so100_follower import SO100Follower
|
||||
from .config_so100_follower_end_effector import SO100FollowerEndEffectorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SO100FollowerEndEffector(SO100Follower):
|
||||
"""
|
||||
SO100Follower robot with end-effector space control.
|
||||
|
||||
This robot inherits from SO100Follower but transforms actions from
|
||||
end-effector space to joint space before sending them to the motors.
|
||||
"""
|
||||
|
||||
config_class = SO100FollowerEndEffectorConfig
|
||||
name = "so100_follower_end_effector"
|
||||
|
||||
def __init__(self, config: SO100FollowerEndEffectorConfig):
|
||||
super().__init__(config)
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREE),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREE),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREE),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREE),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREE),
|
||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.config = config
|
||||
|
||||
# Initialize the kinematics module for the so100 robot
|
||||
self.kinematics = RobotKinematics(robot_type="so101")
|
||||
|
||||
# Set the forward kinematics function
|
||||
self.fk_function = self.kinematics.fk_gripper
|
||||
|
||||
# Store the bounds for end-effector position
|
||||
self.end_effector_bounds = self.config.end_effector_bounds
|
||||
|
||||
# Store the joint mins and maxs
|
||||
self.joint_mins = None
|
||||
self.joint_maxs = None
|
||||
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
|
||||
@property
|
||||
def action_features(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Define action features for end-effector control.
|
||||
Returns dictionary with dtype, shape, and names.
|
||||
"""
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
}
|
||||
|
||||
def connect(self):
|
||||
super().connect()
|
||||
self.joint_mins = self.bus.sync_read("Min_Position_Limit")
|
||||
self.joint_maxs = self.bus.sync_read("Max_Position_Limit")
|
||||
|
||||
def send_action(self, action: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform action from end-effector space to joint space and send to motors.
|
||||
|
||||
Args:
|
||||
action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
|
||||
or a numpy array with [delta_x, delta_y, delta_z]
|
||||
|
||||
Returns:
|
||||
The joint-space action that was sent to the motors
|
||||
"""
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Convert action to numpy array if not already
|
||||
if isinstance(action, dict):
|
||||
if all(k in action for k in ["delta_x", "delta_y", "delta_z", "gripper"]):
|
||||
action = np.array(
|
||||
[
|
||||
action["delta_x"] * self.config.end_effector_step_sizes["x"],
|
||||
action["delta_y"] * self.config.end_effector_step_sizes["y"],
|
||||
action["delta_z"] * self.config.end_effector_step_sizes["z"],
|
||||
action["gripper"],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
|
||||
)
|
||||
action = np.zeros(4, dtype=np.float32)
|
||||
|
||||
if self.current_joint_pos is None:
|
||||
# Read current joint positions
|
||||
current_joint_pos = self.bus.sync_read("Present_Position")
|
||||
self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
|
||||
|
||||
# Calculate current end-effector position using forward kinematics
|
||||
if self.current_ee_pos is None:
|
||||
self.current_ee_pos = self.fk_function(self.current_joint_pos)
|
||||
|
||||
# Set desired end-effector position by adding delta
|
||||
desired_ee_pos = np.eye(4)
|
||||
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
|
||||
|
||||
# Add delta to position and clip to bounds
|
||||
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
|
||||
if self.end_effector_bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(
|
||||
desired_ee_pos[:3, 3],
|
||||
self.end_effector_bounds["min"],
|
||||
self.end_effector_bounds["max"],
|
||||
)
|
||||
|
||||
# Compute inverse kinematics to get joint positions
|
||||
target_joint_values_in_degrees = self.kinematics.ik(
|
||||
self.current_joint_pos,
|
||||
desired_ee_pos,
|
||||
position_only=True,
|
||||
fk_func=self.fk_function,
|
||||
)
|
||||
|
||||
target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0)
|
||||
# Create joint space action dictionary
|
||||
joint_action = {
|
||||
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
|
||||
}
|
||||
|
||||
# Handle gripper separately if included in action
|
||||
joint_action["gripper.pos"] = np.clip(
|
||||
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
|
||||
5,
|
||||
self.config.max_gripper_pos,
|
||||
)
|
||||
|
||||
self.current_ee_pos = desired_ee_pos.copy()
|
||||
self.current_joint_pos = target_joint_values_in_degrees.copy()
|
||||
self.current_joint_pos[-1] = joint_action["gripper.pos"]
|
||||
|
||||
# Send joint space action to parent class
|
||||
return super().send_action(joint_action)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def reset(self):
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
@@ -20,36 +20,6 @@ from lerobot.common.robots import RobotConfig
|
||||
from .robot import Robot
|
||||
|
||||
|
||||
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
if robot_type == "aloha":
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
elif robot_type == "koch_follower":
|
||||
from .koch_follower.config_koch_follower import KochFollowerConfig
|
||||
|
||||
return KochFollowerConfig(**kwargs)
|
||||
elif robot_type == "so100_follower":
|
||||
from .so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
|
||||
return SO100FollowerConfig(**kwargs)
|
||||
elif robot_type == "so100_follower_end_effector":
|
||||
from .so100_follower_end_effector.config_so100_follower_end_effector import (
|
||||
SO100FollowerEndEffectorConfig,
|
||||
)
|
||||
|
||||
return SO100FollowerEndEffectorConfig(**kwargs)
|
||||
elif robot_type == "stretch":
|
||||
from .stretch3.configuration_stretch3 import Stretch3RobotConfig
|
||||
|
||||
return Stretch3RobotConfig(**kwargs)
|
||||
elif robot_type == "lekiwi":
|
||||
from .lekiwi.config_lekiwi import LeKiwiConfig
|
||||
|
||||
return LeKiwiConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
if config.type == "koch_follower":
|
||||
from .koch_follower import KochFollower
|
||||
@@ -59,18 +29,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .so100_follower import SO100Follower
|
||||
|
||||
return SO100Follower(config)
|
||||
elif config.type == "so100_follower_end_effector":
|
||||
from .so100_follower_end_effector import SO100FollowerEndEffector
|
||||
|
||||
return SO100FollowerEndEffector(config)
|
||||
elif config.type == "so101_follower":
|
||||
from .so101_follower import SO101Follower
|
||||
|
||||
return SO101Follower(config)
|
||||
elif config.type == "lekiwi":
|
||||
from .lekiwi import LeKiwiClient
|
||||
from .lekiwi import LeKiwi
|
||||
|
||||
return LeKiwiClient(config)
|
||||
return LeKiwi(config)
|
||||
elif config.type == "stretch3":
|
||||
from .stretch3 import Stretch3Robot
|
||||
|
||||
@@ -123,11 +89,3 @@ def ensure_safe_goal_position(
|
||||
)
|
||||
|
||||
return safe_goal_positions
|
||||
|
||||
|
||||
# TODO(aliberts): Remove
|
||||
def get_arm_id(name, arm_type):
|
||||
"""Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
|
||||
like Aloha, it could be left_follower, right_follower, left_leader, or right_leader.
|
||||
"""
|
||||
return f"{name}_{arm_type}"
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
from .teleop_gamepad import GamepadTeleop
|
||||
|
||||
__all__ = ["GamepadTeleopConfig", "GamepadTeleop"]
|
||||
@@ -1,28 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("gamepad")
|
||||
@dataclass
|
||||
class GamepadTeleopConfig(TeleoperatorConfig):
|
||||
# TODO(Steven): Consider setting in here the keys that we want to capture/listen
|
||||
mock: bool = False
|
||||
|
||||
use_gripper: bool = True
|
||||
@@ -1,716 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.model.kinematics import RobotKinematics
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
self.intervention_flag = False
|
||||
self.open_gripper_command = False
|
||||
self.close_gripper_command = False
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if the user has requested to quit."""
|
||||
return not self.running
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
return status
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
def gripper_command(self):
|
||||
"""Return the current gripper command."""
|
||||
if self.open_gripper_command == self.close_gripper_command:
|
||||
return "no-op"
|
||||
elif self.open_gripper_command:
|
||||
return "open"
|
||||
elif self.close_gripper_command:
|
||||
return "close"
|
||||
|
||||
|
||||
class KeyboardController(InputController):
|
||||
"""Generate motion deltas from keyboard input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.key_states = {
|
||||
"forward_x": False,
|
||||
"backward_x": False,
|
||||
"forward_y": False,
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = True
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = True
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def on_release(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = False
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = False
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = False
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = False
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
|
||||
self.listener.start()
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
delta_x += self.x_step_size
|
||||
if self.key_states["backward_x"]:
|
||||
delta_x -= self.x_step_size
|
||||
if self.key_states["forward_y"]:
|
||||
delta_y += self.y_step_size
|
||||
if self.key_states["backward_y"]:
|
||||
delta_y -= self.y_step_size
|
||||
if self.key_states["forward_z"]:
|
||||
delta_z += self.z_step_size
|
||||
if self.key_states["backward_z"]:
|
||||
delta_z -= self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if ESC was pressed."""
|
||||
return self.key_states["quit"]
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if Enter was pressed (save episode)."""
|
||||
return self.key_states["success"] or self.key_states["failure"]
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
if pygame.joystick.get_count() == 0:
|
||||
logging.error("No gamepad detected. Please connect a gamepad and try again.")
|
||||
self.running = False
|
||||
return
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
if self.joystick:
|
||||
self.joystick.quit()
|
||||
pygame.joystick.quit()
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
x_input = self.joystick.get_axis(0) # Left/Right
|
||||
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -y_input * self.y_step_size # Forward/backward
|
||||
delta_y = -x_input * self.x_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
except pygame.error:
|
||||
logging.error("Error reading gamepad. Is it still connected?")
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_step_size=1.0,
|
||||
y_step_size=1.0,
|
||||
z_step_size=1.0,
|
||||
deadzone=0.1,
|
||||
vendor_id=0x046D,
|
||||
product_id=0xC219,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
vendor_id: USB vendor ID of the gamepad (default: Logitech)
|
||||
product_id: USB product ID of the gamepad (default: RumblePad 2)
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.vendor_id = vendor_id
|
||||
self.product_id = product_id
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
self.quit_requested = False
|
||||
self.save_requested = False
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
|
||||
logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
|
||||
)
|
||||
return None
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
if not self.device_info:
|
||||
self.running = False
|
||||
return
|
||||
|
||||
try:
|
||||
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
|
||||
self.device = hid.device()
|
||||
self.device.open_path(self.device_info["path"])
|
||||
self.device.set_nonblocking(1)
|
||||
|
||||
manufacturer = self.device.get_manufacturer_string()
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
logging.error("You might need to run this with sudo/admin privileges on some systems")
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_y * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_x * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if quit button was pressed."""
|
||||
return self.quit_requested
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if save button was pressed."""
|
||||
return self.save_requested
|
||||
|
||||
|
||||
def test_forward_kinematics(robot, fps=10):
|
||||
logging.info("Testing Forward Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
logging.info(f"EE Position: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def test_inverse_kinematics(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||
desired_ee_pos = ee_pos
|
||||
target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Target Joint State: {target_joint_state}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = leader_ee
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Delta End-Effector Control")
|
||||
timestep = time.perf_counter()
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = np.diag(np.ones(4))
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Get leader state for teleoperation
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
# Get current state
|
||||
# obs = robot.capture_observation()
|
||||
# joint_positions = obs["observation.state"].cpu().numpy()
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Calculate delta between leader and follower end-effectors
|
||||
# Scaling factor can be adjusted for sensitivity
|
||||
scaling_factor = 1.0
|
||||
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
|
||||
|
||||
# Apply delta to current position
|
||||
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
|
||||
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
|
||||
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
|
||||
initial_leader_ee = leader_ee.copy()
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
# Logging
|
||||
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
|
||||
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
|
||||
"""
|
||||
Control a robot using delta end-effector movements from any input controller.
|
||||
|
||||
Args:
|
||||
robot: Robot instance to control
|
||||
controller: InputController instance (keyboard, gamepad, etc.)
|
||||
fps: Control frequency in Hz
|
||||
bounds: Optional position limits
|
||||
fk_func: Forward kinematics function to use
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
|
||||
logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Initialize desired position with current position
|
||||
desired_ee_pos = np.eye(4) # Identity matrix
|
||||
|
||||
timestep = time.perf_counter()
|
||||
with controller:
|
||||
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get current robot state
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Update desired position
|
||||
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
|
||||
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
|
||||
|
||||
# Apply bounds if provided
|
||||
if bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
|
||||
|
||||
# Only send commands if there's actual movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||
"""
|
||||
Control a robot through a gym environment using keyboard inputs.
|
||||
|
||||
Args:
|
||||
env: A gym environment created with make_robot_env
|
||||
fps: Target control frequency
|
||||
"""
|
||||
|
||||
logging.info("Testing Keyboard Control of Gym Environment")
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" ESC: Exit")
|
||||
|
||||
# Reset the environment to get initial observation
|
||||
obs, info = env.reset()
|
||||
|
||||
try:
|
||||
with controller:
|
||||
while not controller.should_quit():
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Create the action vector
|
||||
action = np.array([delta_x, delta_y, delta_z])
|
||||
|
||||
# Skip if no movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Step the environment - pass action as a tensor with intervention flag
|
||||
action_tensor = torch.from_numpy(action.astype(np.float32))
|
||||
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
|
||||
|
||||
# Log information
|
||||
logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
|
||||
logging.info(f"Reward: {reward}")
|
||||
|
||||
# Reset if episode ended
|
||||
if terminated or truncated:
|
||||
logging.info("Episode ended, resetting environment")
|
||||
obs, info = env.reset()
|
||||
|
||||
# Maintain target frame rate
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
finally:
|
||||
# Close the environment
|
||||
env.close()
|
||||
@@ -1,134 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
|
||||
|
||||
class GripperAction(IntEnum):
|
||||
CLOSE = 0
|
||||
STAY = 1
|
||||
OPEN = 2
|
||||
|
||||
|
||||
gripper_action_map = {
|
||||
"close": GripperAction.CLOSE.value,
|
||||
"open": GripperAction.OPEN.value,
|
||||
"stay": GripperAction.STAY.value,
|
||||
}
|
||||
|
||||
|
||||
class GamepadTeleop(Teleoperator):
|
||||
"""
|
||||
Teleop class to use gamepad inputs for control.
|
||||
"""
|
||||
|
||||
config_class = GamepadTeleopConfig
|
||||
name = "gamepad"
|
||||
|
||||
def __init__(self, config: GamepadTeleopConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.robot_type = config.type
|
||||
|
||||
self.event_queue = Queue()
|
||||
self.current_pressed = {}
|
||||
self.listener = None
|
||||
self.logs = {}
|
||||
|
||||
self.gamepad = None
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
if self.config.use_gripper:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (3,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def connect(self) -> None:
|
||||
# use HidApi for macos
|
||||
if sys.platform == "darwin":
|
||||
# NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi
|
||||
from lerobot.common.utils.end_effector_control import GamepadControllerHID as Gamepad
|
||||
else:
|
||||
from lerobot.common.utils.end_effector_control import GamepadController as Gamepad
|
||||
|
||||
self.gamepad = Gamepad(x_step_size=1.0, y_step_size=1.0, z_step_size=1.0)
|
||||
self.gamepad.start()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
# Update the controller to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = self.gamepad.get_deltas()
|
||||
|
||||
# Create action from gamepad input
|
||||
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
|
||||
|
||||
action_dict = {
|
||||
"delta_x": gamepad_action[0],
|
||||
"delta_y": gamepad_action[1],
|
||||
"delta_z": gamepad_action[2],
|
||||
}
|
||||
|
||||
# Default gripper action is to stay
|
||||
gripper_action = GripperAction.STAY.value
|
||||
if self.config.use_gripper:
|
||||
gripper_command = self.gamepad.gripper_command()
|
||||
gripper_action = gripper_action_map[gripper_command]
|
||||
action_dict["gripper"] = gripper_action
|
||||
|
||||
return action_dict
|
||||
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def disconnect(self) -> None:
|
||||
pass
|
||||
@@ -44,11 +44,11 @@ class SO101Leader(Teleoperator):
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREE),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREE),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREE),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREE),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREE),
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
|
||||
@@ -45,9 +45,5 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from tests.mocks.mock_teleop import MockTeleop
|
||||
|
||||
return MockTeleop(config)
|
||||
elif config.type == "gamepad":
|
||||
from .gamepad.teleop_gamepad import GamepadTeleop
|
||||
|
||||
return GamepadTeleop(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command:
|
||||
//
|
||||
// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto
|
||||
//
|
||||
// The command should be launched from the root of the project.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package transport;
|
||||
|
||||
// LearnerService: the Actor calls this to push transitions.
|
||||
// The Learner implements this service.
|
||||
service LearnerService {
|
||||
// Actor -> Learner to store transitions
|
||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||
rpc StreamParameters(Empty) returns (stream Parameters);
|
||||
rpc SendTransitions(stream Transition) returns (Empty);
|
||||
rpc SendInteractions(stream InteractionMessage) returns (Empty);
|
||||
rpc Ready(Empty) returns (Empty);
|
||||
}
|
||||
|
||||
enum TransferState {
|
||||
TRANSFER_UNKNOWN = 0;
|
||||
TRANSFER_BEGIN = 1;
|
||||
TRANSFER_MIDDLE = 2;
|
||||
TRANSFER_END = 3;
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Transition {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Parameters {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message InteractionMessage {
|
||||
TransferState transfer_state = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
@@ -1,45 +0,0 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: lerobot/common/transport/services.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'lerobot/common/transport/services.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xcc\x02\n\x0eLearnerService\x12I\n\x16SendInteractionMessage\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_TRANSFERSTATE']._serialized_start=305
|
||||
_globals['_TRANSFERSTATE']._serialized_end=401
|
||||
_globals['_TRANSITION']._serialized_start=54
|
||||
_globals['_TRANSITION']._serialized_end=130
|
||||
_globals['_PARAMETERS']._serialized_start=132
|
||||
_globals['_PARAMETERS']._serialized_end=208
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=210
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=294
|
||||
_globals['_EMPTY']._serialized_start=296
|
||||
_globals['_EMPTY']._serialized_end=303
|
||||
_globals['_LEARNERSERVICE']._serialized_start=404
|
||||
_globals['_LEARNERSERVICE']._serialized_end=736
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,276 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.71.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class LearnerServiceStub:
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendInteractionMessage = channel.unary_unary(
|
||||
'/transport.LearnerService/SendInteractionMessage',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.StreamParameters = channel.unary_stream(
|
||||
'/transport.LearnerService/StreamParameters',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTransitions = channel.stream_unary(
|
||||
'/transport.LearnerService/SendTransitions',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractions = channel.stream_unary(
|
||||
'/transport.LearnerService/SendInteractions',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.Ready = channel.unary_unary(
|
||||
'/transport.LearnerService/Ready',
|
||||
request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class LearnerServiceServicer:
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""Actor -> Learner to store transitions
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def StreamParameters(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendTransitions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractions(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Ready(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendInteractionMessage,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'StreamParameters': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamParameters,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString,
|
||||
),
|
||||
'SendTransitions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendTransitions,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractions': grpc.stream_unary_rpc_method_handler(
|
||||
servicer.SendInteractions,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Ready,
|
||||
request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'transport.LearnerService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class LearnerService:
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendInteractionMessage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.LearnerService/SendInteractionMessage',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def StreamParameters(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/transport.LearnerService/StreamParameters',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendTransitions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.LearnerService/SendTransitions',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractions(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_unary(
|
||||
request_iterator,
|
||||
target,
|
||||
'/transport.LearnerService/SendInteractions',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def Ready(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/transport.LearnerService/Ready',
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||
lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@@ -1,142 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import pickle # nosec B403: Safe usage for internal serialization only
|
||||
from multiprocessing import Event, Queue
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.common.utils.transition import Transition
|
||||
|
||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||
|
||||
|
||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||
buffer.seek(0, io.SEEK_END)
|
||||
result = buffer.tell()
|
||||
buffer.seek(0)
|
||||
return result
|
||||
|
||||
|
||||
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
||||
buffer = io.BytesIO(buffer)
|
||||
size_in_bytes = bytes_buffer_size(buffer)
|
||||
|
||||
sent_bytes = 0
|
||||
|
||||
logging_method = logging.info if not silent else logging.debug
|
||||
|
||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
||||
|
||||
while sent_bytes < size_in_bytes:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE
|
||||
|
||||
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_END
|
||||
elif sent_bytes == 0:
|
||||
transfer_state = services_pb2.TransferState.TRANSFER_BEGIN
|
||||
|
||||
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
||||
chunk = buffer.read(size_to_read)
|
||||
|
||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
||||
sent_bytes += size_to_read
|
||||
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
|
||||
|
||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||
|
||||
|
||||
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
|
||||
bytes_buffer = io.BytesIO()
|
||||
step = 0
|
||||
|
||||
logging.info(f"{log_prefix} Starting receiver")
|
||||
for item in iterator:
|
||||
logging.debug(f"{log_prefix} Received item")
|
||||
if shutdown_event.is_set():
|
||||
logging.info(f"{log_prefix} Shutting down receiver")
|
||||
return
|
||||
|
||||
if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN:
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step 0")
|
||||
step = 0
|
||||
continue
|
||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE:
|
||||
bytes_buffer.write(item.data)
|
||||
step += 1
|
||||
logging.debug(f"{log_prefix} Received data at step {step}")
|
||||
elif item.transfer_state == services_pb2.TransferState.TRANSFER_END:
|
||||
bytes_buffer.write(item.data)
|
||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||
|
||||
queue.put(bytes_buffer.getvalue())
|
||||
|
||||
bytes_buffer.seek(0)
|
||||
bytes_buffer.truncate(0)
|
||||
step = 0
|
||||
|
||||
logging.debug(f"{log_prefix} Queue updated")
|
||||
|
||||
|
||||
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||
"""Convert model state dict to flat array for transmission"""
|
||||
buffer = io.BytesIO()
|
||||
|
||||
torch.save(state_dict, buffer)
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications.
|
||||
# This is currently safe as we only deserialize trusted internal data.
|
||||
# TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
|
||||
|
||||
|
||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
return pickle.dumps(python_object)
|
||||
|
||||
|
||||
def bytes_to_python_object(buffer: bytes) -> Any:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load
|
||||
# Add validation checks here
|
||||
return obj
|
||||
|
||||
|
||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
|
||||
buffer = io.BytesIO(buffer)
|
||||
buffer.seek(0)
|
||||
transitions = torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
|
||||
# Add validation checks here
|
||||
return transitions
|
||||
|
||||
|
||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(transitions, buffer)
|
||||
return buffer.getvalue()
|
||||
@@ -1,816 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from contextlib import suppress
|
||||
from typing import Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.transition import Transition
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
truncated: torch.Tensor
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
"""
|
||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||
(Same as shown previously.)
|
||||
"""
|
||||
B, C, H, W = images.shape # noqa: N806
|
||||
crop_h, crop_w = output_size
|
||||
|
||||
if crop_h > H or crop_w > W:
|
||||
raise ValueError(
|
||||
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||
)
|
||||
|
||||
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||
|
||||
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||
|
||||
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
return cropped
|
||||
|
||||
|
||||
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||
_, _, h, w = images.shape
|
||||
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||
return random_crop_vectorized(images=images, output_size=(h, w))
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
):
|
||||
"""
|
||||
Replay buffer for storing transitions.
|
||||
It will allocate tensors on the specified device, when the first transition is added.
|
||||
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
|
||||
and use the `storage_device` flag to store the buffer on a different device.
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
|
||||
Using "cpu" can help save GPU memory.
|
||||
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
|
||||
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
|
||||
"""
|
||||
if capacity <= 0:
|
||||
raise ValueError("Capacity must be greater than 0.")
|
||||
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
self.position = 0
|
||||
self.size = 0
|
||||
self.initialized = False
|
||||
self.optimize_memory = optimize_memory
|
||||
|
||||
# Track episode boundaries for memory optimization
|
||||
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
|
||||
self.image_augmentation_function = image_augmentation_function
|
||||
|
||||
if image_augmentation_function is None:
|
||||
base_function = functools.partial(random_shift, pad=4)
|
||||
self.image_augmentation_function = torch.compile(base_function)
|
||||
self.use_drq = use_drq
|
||||
|
||||
def _initialize_storage(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Initialize the storage tensors based on the first transition."""
|
||||
# Determine shapes from the first transition
|
||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||
action_shape = action.squeeze(0).shape
|
||||
|
||||
# Pre-allocate tensors for storage
|
||||
self.states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
|
||||
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach: store states and next_states separately
|
||||
self.next_states = {
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
else:
|
||||
# Memory-optimized approach: don't allocate next_states buffer
|
||||
# Just create a reference to states for consistent API
|
||||
self.next_states = self.states # Just a reference for API consistency
|
||||
|
||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
|
||||
# Initialize storage for complementary_info
|
||||
self.has_complementary_info = complementary_info is not None
|
||||
self.complementary_info_keys = []
|
||||
self.complementary_info = {}
|
||||
|
||||
if self.has_complementary_info:
|
||||
self.complementary_info_keys = list(complementary_info.keys())
|
||||
# Pre-allocate tensors for each key in complementary_info
|
||||
for key, value in complementary_info.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
value_shape = value.squeeze(0).shape
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def add(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
truncated: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Initialize storage if this is the first transition
|
||||
if not self.initialized:
|
||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||
|
||||
# Store the transition in pre-allocated tensors
|
||||
for key in self.states:
|
||||
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Only store next_states if not optimizing memory
|
||||
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
||||
|
||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||
self.rewards[self.position] = reward
|
||||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
|
||||
# Handle complementary_info if provided and storage is initialized
|
||||
if complementary_info is not None and self.has_complementary_info:
|
||||
# Store the complementary_info
|
||||
for key in self.complementary_info_keys:
|
||||
if key in complementary_info:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int, float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
if not self.initialized:
|
||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||
|
||||
batch_size = min(batch_size, self.size)
|
||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||
|
||||
# Random indices for sampling - create on the same device as storage
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
|
||||
# Identify image keys that need augmentation
|
||||
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
||||
|
||||
# Create batched state and next_state
|
||||
batch_state = {}
|
||||
batch_next_state = {}
|
||||
|
||||
# First pass: load all state tensors to target device
|
||||
for key in self.states:
|
||||
batch_state[key] = self.states[key][idx].to(self.device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach - load next_states directly
|
||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||
else:
|
||||
# Memory-optimized approach - get next_state from the next index
|
||||
next_idx = (idx + 1) % self.capacity
|
||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||
|
||||
# Apply image augmentation in a batched way if needed
|
||||
if self.use_drq and image_keys:
|
||||
# Concatenate all images from state and next_state
|
||||
all_images = []
|
||||
for key in image_keys:
|
||||
all_images.append(batch_state[key])
|
||||
all_images.append(batch_next_state[key])
|
||||
|
||||
# Optimization: Batch all images and apply augmentation once
|
||||
all_images_tensor = torch.cat(all_images, dim=0)
|
||||
augmented_images = self.image_augmentation_function(all_images_tensor)
|
||||
|
||||
# Split the augmented images back to their sources
|
||||
for i, key in enumerate(image_keys):
|
||||
# Calculate offsets for the current image key:
|
||||
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
|
||||
# States start at index i*2*batch_size and take up batch_size slots
|
||||
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
|
||||
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||
|
||||
# Sample other tensors
|
||||
batch_actions = self.actions[idx].to(self.device)
|
||||
batch_rewards = self.rewards[idx].to(self.device)
|
||||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
|
||||
# Sample complementary_info if available
|
||||
batch_complementary_info = None
|
||||
if self.has_complementary_info:
|
||||
batch_complementary_info = {}
|
||||
for key in self.complementary_info_keys:
|
||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
truncated=batch_truncateds,
|
||||
complementary_info=batch_complementary_info,
|
||||
)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""
|
||||
Creates an infinite iterator that yields batches of transitions.
|
||||
Will automatically restart when internal iterator is exhausted.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True)
|
||||
queue_size (int): Number of batches to prefetch (default: 2)
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batched transitions
|
||||
"""
|
||||
while True: # Create an infinite loop
|
||||
if async_prefetch:
|
||||
# Get the standard iterator
|
||||
iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size)
|
||||
else:
|
||||
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
|
||||
|
||||
# Yield all items from the iterator
|
||||
with suppress(StopIteration):
|
||||
yield from iterator
|
||||
|
||||
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Creates an iterator that prefetches batches in a background thread.
|
||||
|
||||
Args:
|
||||
queue_size (int): Number of batches to prefetch (default: 2)
|
||||
batch_size (int): Size of batches to sample (default: 128)
|
||||
|
||||
Yields:
|
||||
BatchTransition: Prefetched batch transitions
|
||||
"""
|
||||
import queue
|
||||
import threading
|
||||
|
||||
# Use thread-safe queue
|
||||
data_queue = queue.Queue(maxsize=queue_size)
|
||||
running = [True] # Use list to allow modification in nested function
|
||||
|
||||
def prefetch_worker():
|
||||
while running[0]:
|
||||
try:
|
||||
# Sample data and add to queue
|
||||
data = self.sample(batch_size)
|
||||
data_queue.put(data, block=True, timeout=0.5)
|
||||
except queue.Full:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Prefetch error: {e}")
|
||||
break
|
||||
|
||||
# Start prefetching thread
|
||||
thread = threading.Thread(target=prefetch_worker, daemon=True)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
while running[0]:
|
||||
try:
|
||||
yield data_queue.get(block=True, timeout=0.5)
|
||||
except queue.Empty:
|
||||
if not thread.is_alive():
|
||||
break
|
||||
finally:
|
||||
# Clean up
|
||||
running[0] = False
|
||||
thread.join(timeout=1.0)
|
||||
|
||||
def _get_naive_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
Creates a simple non-threaded iterator that yields batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of batches to sample
|
||||
queue_size (int): Number of initial batches to prefetch
|
||||
|
||||
Yields:
|
||||
BatchTransition: Batch transitions
|
||||
"""
|
||||
import collections
|
||||
|
||||
queue = collections.deque()
|
||||
|
||||
def enqueue(n):
|
||||
for _ in range(n):
|
||||
data = self.sample(batch_size)
|
||||
queue.append(data)
|
||||
|
||||
enqueue(queue_size)
|
||||
while queue:
|
||||
yield queue.popleft()
|
||||
enqueue(1)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
capacity: Optional[int] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
storage_device: str = "cpu",
|
||||
optimize_memory: bool = False,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
||||
Args:
|
||||
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||
device (str): The device for sampling tensors. Defaults to "cuda:0".
|
||||
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
|
||||
image_augmentation_function (Optional[Callable]): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory.
|
||||
optimize_memory (bool): If True, reduces memory usage by not duplicating state data.
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: The replay buffer with dataset transitions.
|
||||
"""
|
||||
if capacity is None:
|
||||
capacity = len(lerobot_dataset)
|
||||
|
||||
if capacity < len(lerobot_dataset):
|
||||
raise ValueError(
|
||||
"The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
|
||||
)
|
||||
|
||||
# Create replay buffer with image augmentation and DrQ settings
|
||||
replay_buffer = cls(
|
||||
capacity=capacity,
|
||||
device=device,
|
||||
state_keys=state_keys,
|
||||
image_augmentation_function=image_augmentation_function,
|
||||
use_drq=use_drq,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=optimize_memory,
|
||||
)
|
||||
|
||||
# Convert dataset to transitions
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
|
||||
# Initialize the buffer with the first transition to set up storage tensors
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition["action"].to(device)
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
if (
|
||||
"complementary_info" in first_transition
|
||||
and first_transition["complementary_info"] is not None
|
||||
):
|
||||
first_complementary_info = {
|
||||
k: v.to(device) for k, v in first_transition["complementary_info"].items()
|
||||
}
|
||||
|
||||
replay_buffer._initialize_storage(
|
||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||
)
|
||||
|
||||
# Fill the buffer with all transitions
|
||||
for data in list_transition:
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
for key, tensor in v.items():
|
||||
v[key] = tensor.to(storage_device)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data["action"]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
||||
complementary_info=data.get("complementary_info", None),
|
||||
)
|
||||
|
||||
return replay_buffer
|
||||
|
||||
def to_lerobot_dataset(
|
||||
self,
|
||||
repo_id: str,
|
||||
fps=1,
|
||||
root=None,
|
||||
task_name="from_replay_buffer",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object.
|
||||
"""
|
||||
if self.size == 0:
|
||||
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
|
||||
|
||||
# Create features dictionary for the dataset
|
||||
features = {
|
||||
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
|
||||
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
|
||||
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
|
||||
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
|
||||
"task_index": {"dtype": "int64", "shape": [1]},
|
||||
}
|
||||
|
||||
# Add "action"
|
||||
sample_action = self.actions[0]
|
||||
act_info = guess_feature_info(t=sample_action, name="action")
|
||||
features["action"] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.states:
|
||||
sample_val = self.states[key][0]
|
||||
f_info = guess_feature_info(t=sample_val, name=key)
|
||||
features[key] = f_info
|
||||
|
||||
# Add complementary_info keys if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
sample_val = self.complementary_info[key][0]
|
||||
if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0:
|
||||
sample_val = sample_val.unsqueeze(0)
|
||||
f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}")
|
||||
features[f"complementary_info.{key}"] = f_info
|
||||
|
||||
# Create an empty LeRobotDataset
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
root=root,
|
||||
robot_type=None,
|
||||
features=features,
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
# Start writing images if needed
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
||||
|
||||
# Convert transitions into episodes and frames
|
||||
episode_index = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
|
||||
|
||||
frame_idx_in_episode = 0
|
||||
for idx in range(self.size):
|
||||
actual_idx = (self.position - self.size + idx) % self.capacity
|
||||
|
||||
frame_dict = {}
|
||||
|
||||
# Fill the data for state keys
|
||||
for key in self.states:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
|
||||
# Add complementary_info if available
|
||||
if self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
val = self.complementary_info[key][actual_idx]
|
||||
# Convert tensors to CPU
|
||||
if isinstance(val, torch.Tensor):
|
||||
if val.ndim == 0:
|
||||
val = val.unsqueeze(0)
|
||||
frame_dict[f"complementary_info.{key}"] = val.cpu()
|
||||
# Non-tensor values can be used directly
|
||||
else:
|
||||
frame_dict[f"complementary_info.{key}"] = val
|
||||
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict, task=task_name)
|
||||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||
lerobot_dataset.save_episode()
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
||||
episode_index=episode_index
|
||||
)
|
||||
|
||||
# Save any remaining frames in the buffer
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
lerobot_dataset.stop_image_writer()
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
@staticmethod
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> list[Transition]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
The dataset to convert. Each item in the dataset is expected to have
|
||||
at least the following keys:
|
||||
{
|
||||
"action": ...
|
||||
"next.reward": ...
|
||||
"next.done": ...
|
||||
"episode_index": ...
|
||||
}
|
||||
plus whatever your 'state_keys' specify.
|
||||
|
||||
state_keys (Optional[Sequence[str]]):
|
||||
The dataset keys to include in 'state' and 'next_state'. Their names
|
||||
will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
if state_keys is None:
|
||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||
|
||||
transitions = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
sample = dataset[0]
|
||||
has_done_key = "next.done" in sample
|
||||
|
||||
# Check for complementary_info keys
|
||||
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
|
||||
has_complementary_info = len(complementary_info_keys) > 0
|
||||
|
||||
# If not, we need to infer it from episode boundaries
|
||||
if not has_done_key:
|
||||
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
|
||||
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
|
||||
if has_done_key:
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
else:
|
||||
# If this is the last frame or if next frame is in a different episode, mark as done
|
||||
done = False
|
||||
if i == num_frames - 1:
|
||||
done = True
|
||||
elif i < num_frames - 1:
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||
done = True
|
||||
|
||||
# TODO: (azouitine) Handle truncation (using the same value as done for now)
|
||||
truncated = done
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- 5) Complementary info (if available) -----
|
||||
complementary_info = None
|
||||
if has_complementary_info:
|
||||
complementary_info = {}
|
||||
for key in complementary_info_keys:
|
||||
# Strip the "complementary_info." prefix to get the actual key
|
||||
clean_key = key[len("complementary_info.") :]
|
||||
val = current_sample[key]
|
||||
# Handle tensor and non-tensor values differently
|
||||
if isinstance(val, torch.Tensor):
|
||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||
else:
|
||||
# TODO: (azouitine) Check if it's necessary to convert to tensor
|
||||
# For non-tensor values, use directly
|
||||
complementary_info[clean_key] = val
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
|
||||
# Utility function to guess shapes/dtypes from a tensor
|
||||
def guess_feature_info(t, name: str):
|
||||
"""
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
||||
If it looks like a 3D (C,H,W) shape, we might consider it an 'image'.
|
||||
Otherwise default to appropriate dtype for numeric.
|
||||
"""
|
||||
|
||||
shape = tuple(t.shape)
|
||||
# Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
|
||||
if len(shape) == 3 and shape[0] in [1, 3]:
|
||||
return {
|
||||
"dtype": "image",
|
||||
"shape": shape,
|
||||
}
|
||||
else:
|
||||
# Otherwise treat as numeric
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": shape,
|
||||
}
|
||||
|
||||
|
||||
def concatenate_batch_transitions(
|
||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||
) -> BatchTransition:
|
||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||
# Concatenate state fields
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["state"][key], right_batch_transition["state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
|
||||
# Concatenate basic fields
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
)
|
||||
|
||||
# Concatenate next_state fields
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]],
|
||||
dim=0,
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
|
||||
# Concatenate done and truncated fields
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
left_batch_transitions["truncated"] = torch.cat(
|
||||
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Handle complementary_info
|
||||
left_info = left_batch_transitions.get("complementary_info")
|
||||
right_info = right_batch_transition.get("complementary_info")
|
||||
|
||||
# Only process if right_info exists
|
||||
if right_info is not None:
|
||||
# Initialize left complementary_info if needed
|
||||
if left_info is None:
|
||||
left_batch_transitions["complementary_info"] = right_info
|
||||
else:
|
||||
# Concatenate each field
|
||||
for key in right_info:
|
||||
if key in left_info:
|
||||
left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0)
|
||||
else:
|
||||
left_info[key] = right_info[key]
|
||||
|
||||
return left_batch_transitions
|
||||
@@ -98,7 +98,12 @@ def is_headless():
|
||||
|
||||
|
||||
def predict_action(
|
||||
observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, use_amp: bool
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
@@ -114,6 +119,9 @@ def predict_action(
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
observation["task"] = task if task else ""
|
||||
observation["robot_type"] = robot_type if robot_type else ""
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
@@ -1,802 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.model.kinematics import RobotKinematics
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
self.intervention_flag = False
|
||||
self.open_gripper_command = False
|
||||
self.close_gripper_command = False
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if the user has requested to quit."""
|
||||
return not self.running
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
return status
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
def gripper_command(self):
|
||||
"""Return the current gripper command."""
|
||||
if self.open_gripper_command == self.close_gripper_command:
|
||||
return "stay"
|
||||
elif self.open_gripper_command:
|
||||
return "open"
|
||||
elif self.close_gripper_command:
|
||||
return "close"
|
||||
|
||||
|
||||
class KeyboardController(InputController):
|
||||
"""Generate motion deltas from keyboard input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.key_states = {
|
||||
"forward_x": False,
|
||||
"backward_x": False,
|
||||
"forward_y": False,
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = True
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = True
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def on_release(key):
|
||||
try:
|
||||
if key == keyboard.Key.up:
|
||||
self.key_states["forward_x"] = False
|
||||
elif key == keyboard.Key.down:
|
||||
self.key_states["backward_x"] = False
|
||||
elif key == keyboard.Key.left:
|
||||
self.key_states["forward_y"] = False
|
||||
elif key == keyboard.Key.right:
|
||||
self.key_states["backward_y"] = False
|
||||
elif key == keyboard.Key.shift:
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
|
||||
self.listener.start()
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
delta_x += self.x_step_size
|
||||
if self.key_states["backward_x"]:
|
||||
delta_x -= self.x_step_size
|
||||
if self.key_states["forward_y"]:
|
||||
delta_y += self.y_step_size
|
||||
if self.key_states["backward_y"]:
|
||||
delta_y -= self.y_step_size
|
||||
if self.key_states["forward_z"]:
|
||||
delta_z += self.z_step_size
|
||||
if self.key_states["backward_z"]:
|
||||
delta_z -= self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if ESC was pressed."""
|
||||
return self.key_states["quit"]
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if Enter was pressed (save episode)."""
|
||||
return self.key_states["success"] or self.key_states["failure"]
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
if pygame.joystick.get_count() == 0:
|
||||
logging.error("No gamepad detected. Please connect a gamepad and try again.")
|
||||
self.running = False
|
||||
return
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
if self.joystick:
|
||||
self.joystick.quit()
|
||||
pygame.joystick.quit()
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
self.episode_end_status = None
|
||||
|
||||
elif event.button == 6:
|
||||
self.close_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
self.open_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
x_input = self.joystick.get_axis(0) # Left/Right
|
||||
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -y_input * self.y_step_size # Forward/backward
|
||||
delta_y = -x_input * self.x_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
except pygame.error:
|
||||
logging.error("Error reading gamepad. Is it still connected?")
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_step_size=1.0,
|
||||
y_step_size=1.0,
|
||||
z_step_size=1.0,
|
||||
deadzone=0.1,
|
||||
vendor_id=0x046D,
|
||||
product_id=0xC219,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
vendor_id: USB vendor ID of the gamepad (default: Logitech)
|
||||
product_id: USB product ID of the gamepad (default: RumblePad 2)
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.vendor_id = vendor_id
|
||||
self.product_id = product_id
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
self.quit_requested = False
|
||||
self.save_requested = False
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
|
||||
logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
|
||||
)
|
||||
return None
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
if not self.device_info:
|
||||
self.running = False
|
||||
return
|
||||
|
||||
try:
|
||||
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
|
||||
self.device = hid.device()
|
||||
self.device.open_path(self.device_info["path"])
|
||||
self.device.set_nonblocking(1)
|
||||
|
||||
manufacturer = self.device.get_manufacturer_string()
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
logging.error("You might need to run this with sudo/admin privileges on some systems")
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
# Apply deadzone
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_y * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_x * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_quit(self):
|
||||
"""Return True if quit button was pressed."""
|
||||
return self.quit_requested
|
||||
|
||||
def should_save(self):
|
||||
"""Return True if save button was pressed."""
|
||||
return self.save_requested
|
||||
|
||||
|
||||
def test_forward_kinematics(robot, fps=10):
|
||||
logging.info("Testing Forward Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
logging.info(f"EE Position: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def test_inverse_kinematics(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
|
||||
desired_ee_pos = ee_pos
|
||||
target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Target Joint State: {target_joint_state}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Inverse Kinematics")
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
timestep = time.perf_counter()
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = leader_ee
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
logging.info(f"Leader EE: {leader_ee[:3, 3]}, Follower EE: {ee_pos[:3, 3]}")
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
|
||||
logging.info("Testing Delta End-Effector Control")
|
||||
timestep = time.perf_counter()
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
initial_leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
desired_ee_pos = np.diag(np.ones(4))
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
fixed_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
while time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Get leader state for teleoperation
|
||||
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
|
||||
leader_ee = kinematics.fk_gripper_tip(leader_joint_positions)
|
||||
|
||||
# Get current state
|
||||
# obs = robot.capture_observation()
|
||||
# joint_positions = obs["observation.state"].cpu().numpy()
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Calculate delta between leader and follower end-effectors
|
||||
# Scaling factor can be adjusted for sensitivity
|
||||
scaling_factor = 1.0
|
||||
ee_delta = -np.clip((leader_ee - initial_leader_ee) * scaling_factor, -0.05, 0.05)
|
||||
|
||||
# Apply delta to current position
|
||||
desired_ee_pos[0, 3] = fixed_ee_pos[0, 3] # current_ee_pos[0, 3] + ee_delta[0, 3] * 0
|
||||
desired_ee_pos[1, 3] = fixed_ee_pos[1, 3] # current_ee_pos[1, 3] + ee_delta[1, 3] * 0
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] - ee_delta[2, 3]
|
||||
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(
|
||||
joint_positions, desired_ee_pos, position_only=True, fk_func=kinematics.fk_gripper_tip
|
||||
)
|
||||
|
||||
initial_leader_ee = leader_ee.copy()
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
# Logging
|
||||
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
|
||||
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
|
||||
"""
|
||||
Control a robot using delta end-effector movements from any input controller.
|
||||
|
||||
Args:
|
||||
robot: Robot instance to control
|
||||
controller: InputController instance (keyboard, gamepad, etc.)
|
||||
fps: Control frequency in Hz
|
||||
bounds: Optional position limits
|
||||
fk_func: Forward kinematics function to use
|
||||
"""
|
||||
if fk_func is None:
|
||||
fk_func = RobotKinematics.fk_gripper_tip
|
||||
|
||||
logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
|
||||
|
||||
# Initial position capture
|
||||
obs = robot.capture_observation()
|
||||
joint_positions = obs["observation.state"].cpu().numpy()
|
||||
kinematics = RobotKinematics(robot.robot_type)
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Initialize desired position with current position
|
||||
desired_ee_pos = np.eye(4) # Identity matrix
|
||||
|
||||
timestep = time.perf_counter()
|
||||
with controller:
|
||||
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get current robot state
|
||||
joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
current_ee_pos = kinematics.fk_gripper_tip(joint_positions)
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Update desired position
|
||||
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
|
||||
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
|
||||
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
|
||||
|
||||
# Apply bounds if provided
|
||||
if bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
|
||||
|
||||
# Only send commands if there's actual movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Compute joint targets via inverse kinematics
|
||||
target_joint_state = kinematics.ik(joint_positions, desired_ee_pos, position_only=True)
|
||||
|
||||
# Send command to robot
|
||||
robot.send_action(torch.from_numpy(target_joint_state))
|
||||
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
|
||||
def teleoperate_gym_env(env, controller, fps: int = 30):
|
||||
"""
|
||||
Control a robot through a gym environment using keyboard inputs.
|
||||
|
||||
Args:
|
||||
env: A gym environment created with make_robot_env
|
||||
fps: Target control frequency
|
||||
"""
|
||||
|
||||
logging.info("Testing Keyboard Control of Gym Environment")
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" ESC: Exit")
|
||||
|
||||
# Reset the environment to get initial observation
|
||||
obs, info = env.reset()
|
||||
|
||||
try:
|
||||
with controller:
|
||||
while not controller.should_quit():
|
||||
loop_start_time = time.perf_counter()
|
||||
|
||||
# Process input events
|
||||
controller.update()
|
||||
|
||||
# Get movement deltas from the controller
|
||||
delta_x, delta_y, delta_z = controller.get_deltas()
|
||||
|
||||
# Create the action vector
|
||||
action = np.array([delta_x, delta_y, delta_z])
|
||||
|
||||
# Skip if no movement
|
||||
if any(abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]):
|
||||
# Step the environment - pass action as a tensor with intervention flag
|
||||
action_tensor = torch.from_numpy(action.astype(np.float32))
|
||||
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
|
||||
|
||||
# Log information
|
||||
logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
|
||||
logging.info(f"Reward: {reward}")
|
||||
|
||||
# Reset if episode ended
|
||||
if terminated or truncated:
|
||||
logging.info("Episode ended, resetting environment")
|
||||
obs, info = env.reset()
|
||||
|
||||
# Maintain target frame rate
|
||||
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
|
||||
|
||||
finally:
|
||||
# Close the environment
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from lerobot.common.envs.configs import EEActionSpaceConfig, EnvTransformConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="keyboard",
|
||||
choices=[
|
||||
"keyboard",
|
||||
"gamepad",
|
||||
"keyboard_gym",
|
||||
"gamepad_gym",
|
||||
"leader_delta",
|
||||
"leader",
|
||||
],
|
||||
help="Control mode to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-type",
|
||||
type=str,
|
||||
default="so100",
|
||||
help="Robot type (so100, koch, aloha, etc.)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_config = RobotConfig.get_choice_class(args.robot_type)(mock=False)
|
||||
robot = make_robot_from_config(robot_config)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
# Example bounds
|
||||
bounds = {
|
||||
"max": np.array([0.32170487, 0.201285, 0.10273342]),
|
||||
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
|
||||
}
|
||||
|
||||
try:
|
||||
# Determine controller type based on mode prefix
|
||||
controller = None
|
||||
if args.mode.startswith("keyboard"):
|
||||
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
elif args.mode.startswith("gamepad"):
|
||||
if sys.platform == "darwin":
|
||||
controller = GamepadControllerHID(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
else:
|
||||
controller = GamepadController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
|
||||
|
||||
# Handle mode categories
|
||||
if args.mode in ["keyboard", "gamepad"]:
|
||||
# Direct robot control modes
|
||||
teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)
|
||||
|
||||
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
|
||||
# Gym environment control modes
|
||||
cfg = HILSerlRobotEnvConfig(robot=robot_config, wrapper=EnvTransformConfig())
|
||||
cfg.wrapper.ee_action_space_params = EEActionSpaceConfig(
|
||||
x_step_size=0.03, y_step_size=0.03, z_step_size=0.03, bounds=bounds
|
||||
)
|
||||
cfg.wrapper.ee_action_space_params.use_gamepad = False
|
||||
cfg.device = "cpu"
|
||||
env = make_robot_env(cfg, robot)
|
||||
teleoperate_gym_env(env, controller, fps=cfg.fps)
|
||||
|
||||
elif args.mode == "leader_delta":
|
||||
# Leader-follower modes don't use controllers
|
||||
teleoperate_delta_inverse_kinematics_with_leader(robot)
|
||||
|
||||
elif args.mode == "leader":
|
||||
teleoperate_inverse_kinematics_with_leader(robot)
|
||||
|
||||
finally:
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
@@ -1,53 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
shutdown_event_counter = 0
|
||||
|
||||
|
||||
def setup_process_handlers(use_threads: bool) -> any:
|
||||
if use_threads:
|
||||
from threading import Event
|
||||
else:
|
||||
from multiprocessing import Event
|
||||
|
||||
shutdown_event = Event()
|
||||
|
||||
# Define signal handler
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
global shutdown_event_counter
|
||||
shutdown_event_counter += 1
|
||||
|
||||
if shutdown_event_counter > 1:
|
||||
logging.info("Force shutdown")
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request (kill)
|
||||
signal.signal(signal.SIGHUP, signal_handler) # Terminal closed/Hangup
|
||||
signal.signal(signal.SIGQUIT, signal_handler) # Ctrl+\
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logging.info("Shutdown signal received. Cleaning up...")
|
||||
shutdown_event.set()
|
||||
|
||||
return shutdown_event
|
||||
@@ -1,35 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from queue import Empty, Queue
|
||||
|
||||
|
||||
def get_last_item_from_queue(queue: Queue):
|
||||
item = queue.get()
|
||||
counter = 1
|
||||
|
||||
# Drain queue and keep only the most recent parameters
|
||||
try:
|
||||
while True:
|
||||
item = queue.get_nowait()
|
||||
counter += 1
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
logging.debug(f"Drained {counter} items from queue")
|
||||
|
||||
return item
|
||||
@@ -1,85 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
truncated: bool
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
|
||||
|
||||
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
|
||||
device = torch.device(device)
|
||||
non_blocking = device.type == "cuda"
|
||||
|
||||
# Move state tensors to device
|
||||
transition["state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items()
|
||||
}
|
||||
|
||||
# Move action to device
|
||||
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move reward and done if they are tensors
|
||||
if isinstance(transition["reward"], torch.Tensor):
|
||||
transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["done"], torch.Tensor):
|
||||
transition["done"] = transition["done"].to(device, non_blocking=non_blocking)
|
||||
|
||||
if isinstance(transition["truncated"], torch.Tensor):
|
||||
transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking)
|
||||
|
||||
# Move next_state tensors to device
|
||||
transition["next_state"] = {
|
||||
key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items()
|
||||
}
|
||||
|
||||
# Move complementary_info tensors if present
|
||||
if transition.get("complementary_info") is not None:
|
||||
for key, val in transition["complementary_info"].items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
return transition
|
||||
|
||||
|
||||
def move_state_dict_to_device(state_dict, device="cpu"):
|
||||
"""
|
||||
Recursively move all tensors in a (potentially) nested
|
||||
dict/list/tuple structure to the CPU.
|
||||
"""
|
||||
if isinstance(state_dict, torch.Tensor):
|
||||
return state_dict.to(device)
|
||||
elif isinstance(state_dict, dict):
|
||||
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
|
||||
elif isinstance(state_dict, list):
|
||||
return [move_state_dict_to_device(v, device=device) for v in state_dict]
|
||||
elif isinstance(state_dict, tuple):
|
||||
return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
|
||||
else:
|
||||
return state_dict
|
||||
@@ -20,11 +20,9 @@ import platform
|
||||
import select
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from copy import copy, deepcopy
|
||||
from copy import copy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -111,17 +109,11 @@ def is_amp_available(device: str):
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
|
||||
# NOTE: Display PID is useful for multi-process logging.
|
||||
if display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
else:
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -135,12 +127,6 @@ def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# Additionally write logs to file
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -253,114 +239,3 @@ def enter_pressed() -> bool:
|
||||
def move_cursor_up(lines):
|
||||
"""Move the cursor up by a specified number of lines."""
|
||||
print(f"\033[{lines}A", end="")
|
||||
|
||||
|
||||
class TimerManager:
|
||||
"""
|
||||
Lightweight utility to measure elapsed time.
|
||||
|
||||
Examples
|
||||
--------
|
||||
```python
|
||||
# Example 1: Using context manager
|
||||
timer = TimerManager("Policy", log=False)
|
||||
for _ in range(3):
|
||||
with timer:
|
||||
time.sleep(0.01)
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
|
||||
```python
|
||||
# Example 2: Using start/stop methods
|
||||
timer = TimerManager("Policy", log=False)
|
||||
timer.start()
|
||||
time.sleep(0.01)
|
||||
timer.stop()
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str = "Elapsed-time",
|
||||
log: bool = True,
|
||||
logger: logging.Logger | None = None,
|
||||
):
|
||||
self.label = label
|
||||
self.log = log
|
||||
self.logger = logger
|
||||
self._start: float | None = None
|
||||
self._history: list[float] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self.start()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
self._start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def stop(self) -> float:
|
||||
if self._start is None:
|
||||
raise RuntimeError("Timer was never started.")
|
||||
elapsed = time.perf_counter() - self._start
|
||||
self._history.append(elapsed)
|
||||
self._start = None
|
||||
if self.log:
|
||||
if self.logger is not None:
|
||||
self.logger.info(f"{self.label}: {elapsed:.6f} s")
|
||||
else:
|
||||
logging.info(f"{self.label}: {elapsed:.6f} s")
|
||||
return elapsed
|
||||
|
||||
def reset(self):
|
||||
self._history.clear()
|
||||
|
||||
@property
|
||||
def last(self) -> float:
|
||||
return self._history[-1] if self._history else 0.0
|
||||
|
||||
@property
|
||||
def avg(self) -> float:
|
||||
return mean(self._history) if self._history else 0.0
|
||||
|
||||
@property
|
||||
def total(self) -> float:
|
||||
return sum(self._history)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self._history)
|
||||
|
||||
@property
|
||||
def history(self) -> list[float]:
|
||||
return deepcopy(self._history)
|
||||
|
||||
@property
|
||||
def fps_history(self) -> list[float]:
|
||||
return [1.0 / t for t in self._history]
|
||||
|
||||
@property
|
||||
def fps_last(self) -> float:
|
||||
return 0.0 if self.last == 0 else 1.0 / self.last
|
||||
|
||||
@property
|
||||
def fps_avg(self) -> float:
|
||||
return 0.0 if self.avg == 0 else 1.0 / self.avg
|
||||
|
||||
def percentile(self, p: float) -> float:
|
||||
"""
|
||||
Return the p-th percentile of recorded times.
|
||||
"""
|
||||
if not self._history:
|
||||
return 0.0
|
||||
return float(np.percentile(self._history, p))
|
||||
|
||||
def fps_percentile(self, p: float) -> float:
|
||||
"""
|
||||
FPS corresponding to the p-th percentile time.
|
||||
"""
|
||||
val = self.percentile(p)
|
||||
return 0.0 if val == 0 else 1.0 / val
|
||||
|
||||
@@ -30,10 +30,9 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"dataset:{cfg.dataset.repo_id}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
@@ -93,12 +92,6 @@ class WandBLogger:
|
||||
resume="must" if cfg.resume else None,
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
run_id = wandb.run.id
|
||||
# NOTE: We will override the cfg.wandb.run_id with the wandb run id.
|
||||
# This is because we want to be able to resume the run from the wandb run id.
|
||||
cfg.wandb.run_id = run_id
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
@@ -115,26 +108,9 @@ class WandBLogger:
|
||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(
|
||||
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
|
||||
):
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
# NOTE: This is not simple. Wandb step must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible. For example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None:
|
||||
if self._wandb_custom_step_key is None:
|
||||
self._wandb_custom_step_key = set()
|
||||
new_custom_key = f"{mode}/{custom_step_key}"
|
||||
if new_custom_key not in self._wandb_custom_step_key:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
@@ -142,18 +118,7 @@ class WandBLogger:
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -87,8 +87,6 @@ class RecordControlConfig(ControlConfig):
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Reset follower arms to an initial position.
|
||||
reset_follower_arms: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
|
||||
@@ -34,10 +34,11 @@ TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
@@ -106,7 +107,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
|
||||
if self.dataset is not None and isinstance(self.dataset.repo_id, list):
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
|
||||
@@ -23,7 +23,6 @@ class FeatureType(str, Enum):
|
||||
VISUAL = "VISUAL"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
@@ -22,7 +22,7 @@ python -m lerobot.find_port
|
||||
```
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -30,7 +30,7 @@ from pathlib import Path
|
||||
def find_available_ports():
|
||||
from serial.tools import list_ports # Part of pyserial library
|
||||
|
||||
if os.name == "nt": # Windows
|
||||
if platform.system() == "Windows":
|
||||
# List COM ports using pyserial
|
||||
ports = [port.device for port in list_ports.comports()]
|
||||
else: # Linux/macOS
|
||||
|
||||
@@ -92,7 +92,7 @@ class DatasetRecordConfig:
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
@@ -159,15 +159,15 @@ class RecordConfig:
|
||||
def record_loop(
|
||||
robot: Robot,
|
||||
events: dict,
|
||||
fps: int,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
control_time_s: int | None = None,
|
||||
fps: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
):
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
# if policy is given it needs cleaning up
|
||||
@@ -186,7 +186,12 @@ def record_loop(
|
||||
|
||||
if policy is not None:
|
||||
action_values = predict_action(
|
||||
observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
observation_frame,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
action = {key: action_values[i] for i, key in enumerate(robot.action_features)}
|
||||
else:
|
||||
@@ -211,12 +216,8 @@ def record_loop(
|
||||
if isinstance(val, float):
|
||||
rr.log(f"action.{act}", rr.Scalar(val))
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
# log_control_info(robot, dt_s, fps=fps)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events["exit_early"]:
|
||||
@@ -243,10 +244,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
)
|
||||
# for key, ft in dataset_features.items():
|
||||
# for property in ["dtype", "shape", "names"]:
|
||||
# if ft[property] != dataset.features[key][property]:
|
||||
# raise ValueError(ft)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
@@ -277,32 +274,16 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
# 2. give times to the robot devices to connect and start synchronizing,
|
||||
# 3. place the cameras windows on screen
|
||||
# enable_teleoperation = policy is None
|
||||
# log_say("Warmup record", cfg.play_sounds)
|
||||
# record_loop(
|
||||
# robot=robot,
|
||||
# control_time_s=cfg.warmup_time_s,
|
||||
# display_data=cfg.display_data,
|
||||
# events=events,
|
||||
# fps=cfg.dataset.fps,
|
||||
# teleoperate=enable_teleoperation,
|
||||
# )
|
||||
# warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.dataset.fps)
|
||||
|
||||
for recorded_episodes in range(cfg.dataset.num_episodes):
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
fps=cfg.dataset.fps,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
@@ -316,9 +297,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
control_time_s=cfg.dataset.reset_time_s,
|
||||
fps=cfg.dataset.fps,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
)
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple script to control a robot from teleoperation.
|
||||
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.scripts.server.find_joint_limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
```
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.model.kinematics_utils import RobotKinematics
|
||||
from lerobot.common.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so100_follower_end_effector,
|
||||
)
|
||||
from lerobot.common.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
gamepad,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
so100_leader,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FindJointLimitsConfig:
|
||||
teleop: TeleoperatorConfig
|
||||
robot: RobotConfig
|
||||
# Limit the maximum frames per second. By default, no limit.
|
||||
fps: int | None = None
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
|
||||
urdf_path: str = "/Users/michel_aractingi/code/SO-ARM100/Simulation/SO101/so101_new_calib.urdf"
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
|
||||
teleop.connect()
|
||||
robot.connect()
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
ee_list = []
|
||||
pos_list = []
|
||||
kinematics = RobotKinematics(cfg.urdf_path)
|
||||
control_time_s = 10
|
||||
while True:
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
|
||||
joint_positions = robot.bus.sync_read("Present_Position")
|
||||
joint_positions = np.array([joint_positions[key] for key in joint_positions])
|
||||
ee_pos, _, _ = kinematics.forward_kinematics(joint_positions * np.pi / 180)
|
||||
ee_list.append(ee_pos.copy())
|
||||
pos_list.append(joint_positions)
|
||||
|
||||
if time.perf_counter() - start_episode_t > control_time_s:
|
||||
max_ee = np.max(np.stack(ee_list), 0)
|
||||
min_ee = np.min(np.stack(ee_list), 0)
|
||||
max_pos = np.max(np.stack(pos_list), 0)
|
||||
min_pos = np.min(np.stack(pos_list), 0)
|
||||
print(f"Max ee position {max_ee}")
|
||||
print(f"Min ee position {min_ee}")
|
||||
print(f"Max joint pos position {max_pos}")
|
||||
print(f"Min joint pos position {min_pos}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
find_joint_and_ee_bounds()
|
||||
@@ -1,726 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Actor server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the actor component of the distributed HILSerl architecture.
|
||||
It executes the policy in the robot environment, collects experience,
|
||||
and sends transitions to the learner server for policy updates.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
- Run with a specific robot type for a pick and place task:
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--robot.type=so100 \
|
||||
--task=pick_and_place
|
||||
```
|
||||
|
||||
- Set a custom workspace bound for the robot's end-effector:
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
|
||||
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
|
||||
```
|
||||
|
||||
- Run with specific camera crop parameters:
|
||||
```bash
|
||||
python lerobot/scripts/rl/actor.py \
|
||||
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
server is started before launching the actor.
|
||||
|
||||
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
5. Start this actor server with the same configuration
|
||||
6. Use human interventions to guide policy learning
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.cameras import opencv # noqa: F401
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robots import so100_follower_end_effector # noqa: F401
|
||||
from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401
|
||||
from lerobot.common.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.common.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.common.utils.process import setup_process_handlers
|
||||
from lerobot.common.utils.queue import get_last_item_from_queue
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
from lerobot.common.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
if not establish_learner_connection(learner_client, shutdown_event):
|
||||
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||
return
|
||||
|
||||
if not use_threads(cfg):
|
||||
# If we use multithreading, we can reuse the channel
|
||||
grpc_channel.close()
|
||||
grpc_channel = None
|
||||
|
||||
logging.info("[ACTOR] Connection with Learner established")
|
||||
|
||||
parameters_queue = Queue()
|
||||
transitions_queue = Queue()
|
||||
interactions_queue = Queue()
|
||||
|
||||
concurrency_entity = None
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
receive_policy_process = concurrency_entity(
|
||||
target=receive_policy,
|
||||
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process = concurrency_entity(
|
||||
target=send_transitions,
|
||||
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
interactions_process = concurrency_entity(
|
||||
target=send_interactions,
|
||||
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process.start()
|
||||
interactions_process.start()
|
||||
receive_policy_process.start()
|
||||
|
||||
act_with_policy(
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
transitions_queue=transitions_queue,
|
||||
interactions_queue=interactions_queue,
|
||||
)
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
logging.info("[ACTOR] Closing queues")
|
||||
transitions_queue.close()
|
||||
interactions_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
transitions_process.join()
|
||||
logging.info("[ACTOR] Transitions process joined")
|
||||
interactions_process.join()
|
||||
logging.info("[ACTOR] Interactions process joined")
|
||||
receive_policy_process.join()
|
||||
logging.info("[ACTOR] Receive policy process joined")
|
||||
|
||||
logging.info("[ACTOR] join queues")
|
||||
transitions_queue.cancel_join_thread()
|
||||
interactions_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg: Configuration settings for the interaction process.
|
||||
shutdown_event: Event to check if the process should shutdown.
|
||||
parameters_queue: Queue to receive updated network parameters from the learner.
|
||||
transitions_queue: Queue to send transitions to the learner.
|
||||
interactions_queue: Queue to send interactions to the learner.
|
||||
"""
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor policy process logging initialized")
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
policy_timer = TimerManager("Policy inference", log=False)
|
||||
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(policy_timer)
|
||||
policy_timer.reset()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
if episode_total_steps > 0:
|
||||
intervention_rate = episode_intervention_steps / episode_total_steps
|
||||
|
||||
# Send episodic reward to the learner
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
"Intervention rate": intervention_rate,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
obs, info = online_env.reset()
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||
shutdown_event (Event): The event to check if the connection should be established.
|
||||
attempts (int): The number of attempts to establish the connection.
|
||||
Returns:
|
||||
bool: True if the connection is established, False otherwise.
|
||||
"""
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
return False
|
||||
|
||||
# Force a connection attempt and check state
|
||||
try:
|
||||
logging.info("[ACTOR] Send ready message to Learner")
|
||||
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
|
||||
return True
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.enable_retries", 1),
|
||||
("grpc.service_config", service_config_json),
|
||||
],
|
||||
)
|
||||
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
return stub, channel
|
||||
|
||||
|
||||
def receive_policy(
|
||||
cfg: TrainPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(use_threads=False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
iterator = learner_client.StreamParameters(services_pb2.Empty())
|
||||
receive_bytes_in_chunks(
|
||||
iterator,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def send_transitions(
|
||||
cfg: TrainPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Transition Data:
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor transitions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
def send_interactions(
|
||||
cfg: TrainPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor interactions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
setup_process_handlers(False)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> services_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Transition queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event, # type: ignore
|
||||
interactions_queue: Queue,
|
||||
) -> services_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=5)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Interaction queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message,
|
||||
services_pb2.InteractionMessage,
|
||||
log_prefix="[ACTOR] Send interactions",
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
if not parameters_queue.empty():
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue)
|
||||
state_dict = bytes_to_state_dict(bytes_state_dict)
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
transitions: List of transitions to send
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
transition_to_send_to_learner = []
|
||||
for transition in transitions:
|
||||
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||
for key, value in tr["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
|
||||
transition_to_send_to_learner.append(tr)
|
||||
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
timer (TimerManager): The timer with collected metrics.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
if timer.count > 1:
|
||||
avg_fps = timer.fps_avg
|
||||
p90_fps = timer.fps_percentile(90)
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||
stats = {
|
||||
"Policy frequency [Hz]": avg_fps,
|
||||
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.env.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.actor == "threads"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
actor_cli()
|
||||
@@ -1,303 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import cv2
|
||||
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
"""
|
||||
Allows the user to draw a rectangular ROI on the image.
|
||||
|
||||
The user must click and drag to draw the rectangle.
|
||||
- While dragging, the rectangle is dynamically drawn.
|
||||
- On mouse button release, the rectangle is fixed.
|
||||
- Press 'c' to confirm the selection.
|
||||
- Press 'r' to reset the selection.
|
||||
- Press ESC to cancel.
|
||||
|
||||
Returns:
|
||||
A tuple (top, left, height, width) representing the rectangular ROI,
|
||||
or None if no valid ROI is selected.
|
||||
"""
|
||||
# Create a working copy of the image
|
||||
clone = img.copy()
|
||||
working_img = clone.copy()
|
||||
|
||||
roi = None # Will store the final ROI as (top, left, height, width)
|
||||
drawing = False
|
||||
index_x, index_y = -1, -1 # Initial click coordinates
|
||||
|
||||
def mouse_callback(event, x, y, flags, param):
|
||||
nonlocal index_x, index_y, drawing, roi, working_img
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Start drawing: record starting coordinates
|
||||
drawing = True
|
||||
index_x, index_y = x, y
|
||||
|
||||
elif event == cv2.EVENT_MOUSEMOVE:
|
||||
if drawing:
|
||||
# Compute the top-left and bottom-right corners regardless of drag direction
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
# Show a temporary image with the current rectangle drawn
|
||||
temp = working_img.copy()
|
||||
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", temp)
|
||||
|
||||
elif event == cv2.EVENT_LBUTTONUP:
|
||||
# Finish drawing
|
||||
drawing = False
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
height = bottom - top
|
||||
width = right - left
|
||||
roi = (top, left, height, width) # (top, left, height, width)
|
||||
# Draw the final rectangle on the working image and display it
|
||||
working_img = clone.copy()
|
||||
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
# Create the window and set the callback
|
||||
cv2.namedWindow("Select ROI")
|
||||
cv2.setMouseCallback("Select ROI", mouse_callback)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
print("Instructions for ROI selection:")
|
||||
print(" - Click and drag to draw a rectangular ROI.")
|
||||
print(" - Press 'c' to confirm the selection.")
|
||||
print(" - Press 'r' to reset and draw again.")
|
||||
print(" - Press ESC to cancel the selection.")
|
||||
|
||||
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
# Confirm ROI if one has been drawn
|
||||
if key == ord("c") and roi is not None:
|
||||
break
|
||||
# Reset: clear the ROI and restore the original image
|
||||
elif key == ord("r"):
|
||||
working_img = clone.copy()
|
||||
roi = None
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
# Cancel selection for this image
|
||||
elif key == 27: # ESC key
|
||||
roi = None
|
||||
break
|
||||
|
||||
cv2.destroyWindow("Select ROI")
|
||||
return roi
|
||||
|
||||
|
||||
def select_square_roi_for_images(images: dict) -> dict:
|
||||
"""
|
||||
For each image in the provided dictionary, open a window to allow the user
|
||||
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
|
||||
(top, left, height, width) representing the ROI.
|
||||
|
||||
Parameters:
|
||||
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of image keys to the selected rectangular ROI.
|
||||
"""
|
||||
selected_rois = {}
|
||||
|
||||
for key, img in images.items():
|
||||
if img is None:
|
||||
print(f"Image for key '{key}' is None, skipping.")
|
||||
continue
|
||||
|
||||
print(f"\nSelect rectangular ROI for image with key: '{key}'")
|
||||
roi = select_rect_roi(img)
|
||||
|
||||
if roi is None:
|
||||
print(f"No valid ROI selected for '{key}'.")
|
||||
else:
|
||||
selected_rois[key] = roi
|
||||
print(f"ROI for '{key}': {roi}")
|
||||
|
||||
return selected_rois
|
||||
|
||||
|
||||
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
"""
|
||||
Find the first row in the dataset and extract the image in order to be used for the crop.
|
||||
"""
|
||||
row = dataset[0]
|
||||
image_dict = {}
|
||||
for k in row:
|
||||
if "image" in k:
|
||||
image_dict[k] = deepcopy(row[k])
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: Tuple[int, int] = (128, 128),
|
||||
push_to_hub: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
applying cropping and resizing to image observations, and saving a new dataset
|
||||
with the transformed data.
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
|
||||
and resized.
|
||||
"""
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
# Create a copy of the frame to add to the new dataset
|
||||
new_frame = {}
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
if key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
|
||||
new_frame[key] = value
|
||||
|
||||
new_dataset.add_frame(new_frame)
|
||||
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
prev_episode_index = frame["episode_index"].item()
|
||||
|
||||
if push_to_hub:
|
||||
new_dataset.push_to_hub()
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="The repository id of the LeRobot dataset to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,84 +0,0 @@
|
||||
import logging
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
"""
|
||||
Implementation of the LearnerService gRPC service
|
||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||
check transport.proto for the gRPC service definition
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event, # type: ignore
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
):
|
||||
self.shutdown_event = shutdown_event
|
||||
self.parameters_queue = parameters_queue
|
||||
self.seconds_between_pushes = seconds_between_pushes
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Push parameters to the Actor")
|
||||
buffer = self.parameters_queue.get()
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
buffer,
|
||||
services_pb2.Parameters,
|
||||
log_prefix="[LEARNER] Sending parameters",
|
||||
silent=True,
|
||||
)
|
||||
|
||||
logging.info("[LEARNER] Parameters sent")
|
||||
|
||||
self.shutdown_event.wait(self.seconds_between_pushes)
|
||||
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.transition_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] transitions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.interaction_message_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] interactions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
@@ -47,7 +47,6 @@ from lerobot.common.robots import ( # noqa: F401
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
so100_follower_end_effector,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.common.teleoperators import (
|
||||
@@ -55,25 +54,26 @@ from lerobot.common.teleoperators import (
|
||||
TeleoperatorConfig,
|
||||
make_teleoperator_from_config,
|
||||
)
|
||||
from lerobot.common.utils.robot_utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.common.utils.visualization_utils import _init_rerun
|
||||
|
||||
from .common.teleoperators import gamepad, koch_leader, so100_leader # noqa: F401
|
||||
from .common.teleoperators import koch_leader, so100_leader, so101_leader # noqa: F401
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeleoperateConfig:
|
||||
teleop: TeleoperatorConfig
|
||||
robot: RobotConfig
|
||||
# Limit the maximum frames per second. By default, no limit.
|
||||
fps: int | None = None
|
||||
# Limit the maximum frames per second.
|
||||
fps: int = 60
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
|
||||
|
||||
def teleop_loop(
|
||||
teleop: Teleoperator, robot: Robot, display_data: bool = False, duration: float | None = None
|
||||
teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None
|
||||
):
|
||||
display_len = max(len(key) for key in robot.action_features)
|
||||
start = time.perf_counter()
|
||||
@@ -92,6 +92,9 @@ def teleop_loop(
|
||||
rr.log(f"action_{act}", rr.Scalar(val))
|
||||
|
||||
robot.send_action(action)
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
@@ -120,7 +123,7 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
teleop_loop(teleop, robot, display_data=cfg.display_data, duration=cfg.teleop_time_s)
|
||||
teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
|
||||
@@ -53,7 +53,7 @@ dependencies = [
|
||||
"einops>=0.8.0",
|
||||
"flask>=3.0.3",
|
||||
"gdown>=5.1.0",
|
||||
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
"h5py>=3.10.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
|
||||
"imageio[ffmpeg]>=2.34.0",
|
||||
@@ -82,57 +82,32 @@ dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
|
||||
dora = [
|
||||
"gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
|
||||
]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
"pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pynput>=1.7.7",
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
|
||||
]
|
||||
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"]
|
||||
test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
|
||||
umi = ["imagecodecs>=2024.1.1"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
|
||||
|
||||
|
||||
[tool.poetry]
|
||||
requires-poetry = ">=2.1"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
"tests/data",
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
"*_pb2.py",
|
||||
"*_pb2_grpc.py",
|
||||
]
|
||||
|
||||
exclude = ["tests/artifacts/**/*.safetensors"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
@@ -141,9 +141,7 @@ def test_async_read_timeout():
|
||||
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(
|
||||
timeout_ms=0
|
||||
) # NOTE(Steven): This is flaky as sdometimes we actually get a frame
|
||||
camera.async_read(timeout_ms=0)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
@@ -162,9 +162,7 @@ def test_async_read_timeout():
|
||||
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(
|
||||
timeout_ms=0
|
||||
) # NOTE(Steven): This is flaky as sdometimes we actually get a frame
|
||||
camera.async_read(timeout_ms=0)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
@@ -14,15 +14,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import traceback
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot import available_cameras
|
||||
from tests.utils import DEVICE, make_camera
|
||||
from tests.utils import DEVICE
|
||||
|
||||
# Import fixture modules as plugins
|
||||
pytest_plugins = [
|
||||
@@ -65,11 +62,6 @@ def _check_component_availability(component_type, available_components, make_com
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_camera_available(camera_type):
|
||||
return _check_component_availability(camera_type, available_cameras, make_camera)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_builtins_input(monkeypatch):
|
||||
def print_text(text=None):
|
||||
@@ -77,19 +69,3 @@ def patch_builtins_input(monkeypatch):
|
||||
print(text)
|
||||
|
||||
monkeypatch.setattr("builtins.input", print_text)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--seed",
|
||||
action="store",
|
||||
default="42",
|
||||
help="Set random seed for reproducibility",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed(request):
|
||||
seed = int(request.config.getoption("--seed"))
|
||||
random.seed(seed) # Python random
|
||||
torch.manual_seed(seed) # PyTorch
|
||||
|
||||
@@ -48,7 +48,7 @@ DXL_CRC_TABLE = [
|
||||
|
||||
class MockDynamixelPacketv2(abc.ABC):
|
||||
@classmethod
|
||||
def build(cls, dxl_id: int, params: list[int], length: list[int], *args, **kwargs) -> bytes:
|
||||
def build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> bytes:
|
||||
packet = cls._build(dxl_id, params, length, *args, **kwargs)
|
||||
packet = cls._add_stuffing(packet)
|
||||
packet = cls._add_crc(packet)
|
||||
@@ -281,7 +281,7 @@ class MockInstructionPacket(MockDynamixelPacketv2):
|
||||
@classmethod
|
||||
def sync_write(
|
||||
cls,
|
||||
ids_values: dict[int],
|
||||
ids_values: dict[int, int],
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
|
||||
@@ -151,7 +151,7 @@ class MockInstructionPacket(MockFeetechPacket):
|
||||
@classmethod
|
||||
def sync_write(
|
||||
cls,
|
||||
ids_values: dict[int],
|
||||
ids_values: dict[int, int],
|
||||
start_address: int,
|
||||
data_length: int,
|
||||
) -> bytes:
|
||||
@@ -219,7 +219,7 @@ class MockStatusPacket(MockFeetechPacket):
|
||||
|
||||
Args:
|
||||
scs_id (int): ID of the servo responding.
|
||||
error (str, optional): Error to be returned. Defaults to "Success".
|
||||
error (int, optional): Error to be returned. Defaults to 0 (success).
|
||||
|
||||
Returns:
|
||||
bytes: The raw 'Ping' status packet ready to be sent through serial.
|
||||
|
||||
@@ -21,7 +21,6 @@ from lerobot.common.constants import (
|
||||
from lerobot.common.optim.optimizers import (
|
||||
AdamConfig,
|
||||
AdamWConfig,
|
||||
MultiAdamConfig,
|
||||
SGDConfig,
|
||||
load_optimizer_state,
|
||||
save_optimizer_state,
|
||||
@@ -34,21 +33,13 @@ from lerobot.common.optim.optimizers import (
|
||||
(AdamConfig, torch.optim.Adam),
|
||||
(AdamWConfig, torch.optim.AdamW),
|
||||
(SGDConfig, torch.optim.SGD),
|
||||
(MultiAdamConfig, dict),
|
||||
],
|
||||
)
|
||||
def test_optimizer_build(config_cls, expected_class, model_params):
|
||||
config = config_cls()
|
||||
if config_cls == MultiAdamConfig:
|
||||
params_dict = {"default": model_params}
|
||||
optimizer = config.build(params_dict)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert isinstance(optimizer["default"], torch.optim.Adam)
|
||||
assert optimizer["default"].defaults["lr"] == config.lr
|
||||
else:
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
optimizer = config.build(model_params)
|
||||
assert isinstance(optimizer, expected_class)
|
||||
assert optimizer.defaults["lr"] == config.lr
|
||||
|
||||
|
||||
def test_save_optimizer_state(optimizer, tmp_path):
|
||||
@@ -63,180 +54,3 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
|
||||
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
|
||||
|
||||
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_params_dict():
|
||||
return {
|
||||
"actor": [torch.nn.Parameter(torch.randn(10, 10))],
|
||||
"critic": [torch.nn.Parameter(torch.randn(5, 5))],
|
||||
"temperature": [torch.nn.Parameter(torch.randn(3, 3))],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_params, expected_values",
|
||||
[
|
||||
# Test 1: Basic configuration with different learning rates
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 2: Different weight decays and beta values
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6},
|
||||
"temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
|
||||
},
|
||||
),
|
||||
# Test 3: Epsilon parameter customization
|
||||
(
|
||||
{
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-4,
|
||||
"optimizer_groups": {
|
||||
"actor": {"lr": 1e-4, "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "eps": 1e-8},
|
||||
},
|
||||
},
|
||||
{
|
||||
"actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
|
||||
"critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
|
||||
"temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
|
||||
# Create config with the given parameters
|
||||
config = MultiAdamConfig(**config_params)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Verify optimizer count and keys
|
||||
assert len(optimizers) == len(expected_values)
|
||||
assert set(optimizers.keys()) == set(expected_values.keys())
|
||||
|
||||
# Check that all optimizers are Adam instances
|
||||
for opt in optimizers.values():
|
||||
assert isinstance(opt, torch.optim.Adam)
|
||||
|
||||
# Verify hyperparameters for each optimizer
|
||||
for name, expected in expected_values.items():
|
||||
optimizer = optimizers[name]
|
||||
for param, value in expected.items():
|
||||
assert optimizer.defaults[param] == value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_optimizers(base_params_dict):
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
return config.build(base_params_dict)
|
||||
|
||||
|
||||
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Verify that directories were created for each optimizer
|
||||
for name in multi_optimizers:
|
||||
assert (tmp_path / name).is_dir()
|
||||
assert (tmp_path / name / OPTIMIZER_STATE).is_file()
|
||||
assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file()
|
||||
|
||||
|
||||
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
|
||||
# Option 1: Add a minimal backward pass to populate optimizer states
|
||||
for name, params in base_params_dict.items():
|
||||
if name in multi_optimizers:
|
||||
# Create a dummy loss and do backward
|
||||
dummy_loss = params[0].sum()
|
||||
dummy_loss.backward()
|
||||
# Perform an optimization step
|
||||
multi_optimizers[name].step()
|
||||
# Zero gradients for next steps
|
||||
multi_optimizers[name].zero_grad()
|
||||
|
||||
# Save optimizer states
|
||||
save_optimizer_state(multi_optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify state dictionaries match
|
||||
for name in multi_optimizers:
|
||||
torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict())
|
||||
|
||||
|
||||
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
|
||||
"""Test saving and loading optimizer states even when the state is empty (no backward pass)."""
|
||||
# Create config and build optimizers
|
||||
config = MultiAdamConfig(
|
||||
lr=1e-3,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 1e-4},
|
||||
"critic": {"lr": 5e-4},
|
||||
"temperature": {"lr": 2e-3},
|
||||
},
|
||||
)
|
||||
optimizers = config.build(base_params_dict)
|
||||
|
||||
# Save optimizer states without any backward pass (empty state)
|
||||
save_optimizer_state(optimizers, tmp_path)
|
||||
|
||||
# Create new optimizers with the same config
|
||||
new_optimizers = config.build(base_params_dict)
|
||||
|
||||
# Load optimizer states
|
||||
loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)
|
||||
|
||||
# Verify hyperparameters match even with empty state
|
||||
for name, optimizer in optimizers.items():
|
||||
assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"]
|
||||
assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"]
|
||||
assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"]
|
||||
|
||||
# Verify state dictionaries match (they will be empty)
|
||||
torch.testing.assert_close(
|
||||
optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"]
|
||||
)
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
output = ClassifierOutput(
|
||||
logits=torch.tensor([1, 2, 3]),
|
||||
probabilities=torch.tensor([0.1, 0.2, 0.3]),
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
f"{output}"
|
||||
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
||||
)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
config.num_cameras = 1
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.size() == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
}
|
||||
config.num_cameras = 1
|
||||
config.num_classes = num_classes
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
"observation.image": torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.rand((batch_size, num_classes)),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size, num_classes])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
@@ -1,217 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
CriticNetworkConfig,
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
config = SACConfig()
|
||||
|
||||
assert config.normalization_mapping == {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
assert config.dataset_stats == {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
|
||||
# Basic parameters
|
||||
assert config.device == "cpu"
|
||||
assert config.storage_device == "cpu"
|
||||
assert config.discount == 0.99
|
||||
assert config.temperature_init == 1.0
|
||||
assert config.num_critics == 2
|
||||
|
||||
# Architecture specifics
|
||||
assert config.vision_encoder_name is None
|
||||
assert config.freeze_vision_encoder is True
|
||||
assert config.image_encoder_hidden_dim == 32
|
||||
assert config.shared_encoder is True
|
||||
assert config.num_discrete_actions is None
|
||||
assert config.image_embedding_pooling_dim == 8
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
assert config.online_env_seed == 10000
|
||||
assert config.online_buffer_capacity == 100000
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
assert config.online_step_before_learning == 100
|
||||
assert config.policy_update_freq == 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
assert config.num_subsample_critics is None
|
||||
assert config.critic_lr == 3e-4
|
||||
assert config.actor_lr == 3e-4
|
||||
assert config.temperature_lr == 3e-4
|
||||
assert config.critic_target_update_weight == 0.005
|
||||
assert config.utd_ratio == 1
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
assert config.target_entropy is None
|
||||
assert config.use_backup_entropy is True
|
||||
assert config.grad_clip_norm == 40.0
|
||||
|
||||
# Dataset stats defaults
|
||||
expected_dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
assert config.dataset_stats == expected_dataset_stats
|
||||
|
||||
# Critic network configuration
|
||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.critic_network_kwargs.activate_final is True
|
||||
assert config.critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor network configuration
|
||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.actor_network_kwargs.activate_final is True
|
||||
|
||||
# Policy configuration
|
||||
assert config.policy_kwargs.use_tanh_squash is True
|
||||
assert config.policy_kwargs.log_std_min == 1e-5
|
||||
assert config.policy_kwargs.log_std_max == 10.0
|
||||
assert config.policy_kwargs.init_final == 0.05
|
||||
|
||||
# Discrete critic network configuration
|
||||
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.discrete_critic_network_kwargs.activate_final is True
|
||||
assert config.discrete_critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor learner configuration
|
||||
assert config.actor_learner_config.learner_host == "127.0.0.1"
|
||||
assert config.actor_learner_config.learner_port == 50051
|
||||
assert config.actor_learner_config.policy_parameters_push_frequency == 4
|
||||
|
||||
# Concurrency configuration
|
||||
assert config.concurrency.actor == "threads"
|
||||
assert config.concurrency.learner == "threads"
|
||||
|
||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||
|
||||
|
||||
def test_critic_network_kwargs():
|
||||
config = CriticNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
assert config.final_activation is None
|
||||
|
||||
|
||||
def test_actor_network_kwargs():
|
||||
config = ActorNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
|
||||
|
||||
def test_policy_kwargs():
|
||||
config = PolicyConfig()
|
||||
assert config.use_tanh_squash is True
|
||||
assert config.log_std_min == 1e-5
|
||||
assert config.log_std_max == 10.0
|
||||
assert config.init_final == 0.05
|
||||
|
||||
|
||||
def test_actor_learner_config():
|
||||
config = ActorLearnerConfig()
|
||||
assert config.learner_host == "127.0.0.1"
|
||||
assert config.learner_port == 50051
|
||||
assert config.policy_parameters_push_frequency == 4
|
||||
|
||||
|
||||
def test_concurrency_config():
|
||||
config = ConcurrencyConfig()
|
||||
assert config.actor == "threads"
|
||||
assert config.learner == "threads"
|
||||
|
||||
|
||||
def test_sac_config_custom_initialization():
|
||||
config = SACConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
num_critics=3,
|
||||
)
|
||||
|
||||
assert config.device == "cpu"
|
||||
assert config.discount == 0.95
|
||||
assert config.temperature_init == 0.5
|
||||
assert config.num_critics == 3
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
||||
):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
|
||||
config.validate_features()
|
||||
@@ -1,519 +0,0 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
def test_mlp_with_default_args():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
|
||||
x = torch.randn(10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (256,)
|
||||
|
||||
|
||||
def test_mlp_with_batch_dim():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
x = torch.randn(2, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (2, 256)
|
||||
|
||||
|
||||
def test_forward_with_empty_hidden_dims():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
||||
x = torch.randn(1, 10)
|
||||
assert mlp(x).shape == (1, 10)
|
||||
|
||||
|
||||
def test_mlp_with_dropout():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 11)
|
||||
|
||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
||||
assert drop_out_layers_count == 2
|
||||
|
||||
|
||||
def test_mlp_with_custom_final_activation():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 256)
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
|
||||
|
||||
def test_sac_policy_with_default_args():
|
||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
||||
SACPolicy()
|
||||
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
|
||||
return torch.randn(batch_size, action_dim)
|
||||
|
||||
|
||||
def create_default_train_batch(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_state(batch_size, state_dim),
|
||||
"next_state": create_dummy_state(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_train_batch_with_visual_input(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
"action": create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
"observation.state": torch.randn(batch_size, state_dim),
|
||||
"observation.image": torch.randn(batch_size, 3, 84, 84),
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
action_dim = continuous_action_dim
|
||||
if has_discrete_action:
|
||||
action_dim += 1
|
||||
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
"observation.state": {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0] * continuous_action_dim,
|
||||
"max": [1.0] * continuous_action_dim,
|
||||
},
|
||||
},
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def create_config_with_visual_input(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
config = create_default_config(
|
||||
state_dim=state_dim,
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
has_discrete_action=has_discrete_action,
|
||||
)
|
||||
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats["observation.image"] = {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_policy_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_policy_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
|
||||
num_discrete_actions = 5
|
||||
config.num_discrete_actions = num_discrete_actions
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
||||
assert discrete_critic_loss.item() is not None
|
||||
assert discrete_critic_loss.shape == ()
|
||||
discrete_critic_loss.backward()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, full_action_dim)
|
||||
|
||||
discrete_actions = selected_action[:, -1].long()
|
||||
discrete_action_values = set(discrete_actions.tolist())
|
||||
|
||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
||||
)
|
||||
|
||||
|
||||
def test_sac_policy_with_default_entropy():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -3.0
|
||||
|
||||
|
||||
def test_sac_policy_with_predefined_entropy():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.target_entropy = -3.5
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
assert policy.temperature == pytest.approx(1.0)
|
||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
policy.update_temperature()
|
||||
assert policy.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_policy_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
for p in policy.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
policy.update_target_networks()
|
||||
for p in policy.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
assert len(policy.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
# They should be the same
|
||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
@@ -45,12 +45,7 @@ def test_available_policies():
|
||||
This test verifies that the class attribute `name` for all policies is
|
||||
consistent with those listed in `lerobot/__init__.py`.
|
||||
"""
|
||||
policy_classes = [
|
||||
ACTPolicy,
|
||||
DiffusionPolicy,
|
||||
TDMPCPolicy,
|
||||
VQBeTPolicy,
|
||||
]
|
||||
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
|
||||
policies = [pol_cls.name for pol_cls in policy_classes]
|
||||
assert set(policies) == set(lerobot.available_policies), policies
|
||||
|
||||
|
||||
@@ -21,9 +21,6 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.cameras import Camera
|
||||
from lerobot.common.motors.motors_bus import MotorsBus
|
||||
from lerobot.common.motors.utils import make_motors_bus as make_motors_bus_device
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
|
||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||
@@ -185,63 +182,3 @@ def require_package(package_name):
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_camera(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Access the pytest request context to get the is_camera_available fixture
|
||||
request = kwargs.get("request")
|
||||
camera_type = kwargs.get("camera_type")
|
||||
mock = kwargs.get("mock")
|
||||
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be an argument of the test function.")
|
||||
if camera_type is None:
|
||||
raise ValueError("The 'camera_type' must be an argument of the test function.")
|
||||
if mock is None:
|
||||
raise ValueError("The 'mock' variable must be an argument of the test function.")
|
||||
|
||||
if not mock and not request.getfixturevalue("is_camera_available"):
|
||||
pytest.skip(f"A {camera_type} camera is not available.")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
||||
def make_camera(camera_type: str, **kwargs) -> Camera:
|
||||
if camera_type == "opencv":
|
||||
camera_index = kwargs.pop("camera_index", OPENCV_CAMERA_INDEX)
|
||||
kwargs["camera_index"] = camera_index
|
||||
from lerobot.common.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
|
||||
config = OpenCVCameraConfig(**kwargs)
|
||||
return OpenCVCamera(config)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
serial_number = kwargs.pop("serial_number", INTELREALSENSE_SERIAL_NUMBER)
|
||||
kwargs["serial_number"] = serial_number
|
||||
from lerobot.common.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
|
||||
|
||||
config = RealSenseCameraConfig(**kwargs)
|
||||
return RealSenseCamera(config)
|
||||
else:
|
||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
if motor_type == "dynamixel":
|
||||
port = kwargs.pop("port", DYNAMIXEL_PORT)
|
||||
motors = kwargs.pop("motors", DYNAMIXEL_MOTORS)
|
||||
return make_motors_bus_device(motor_type, port=port, motors=motors, **kwargs)
|
||||
|
||||
elif motor_type == "feetech":
|
||||
port = kwargs.pop("port", FEETECH_PORT)
|
||||
motors = kwargs.pop("motors", FEETECH_MOTORS)
|
||||
return make_motors_bus_device(motor_type, port=port, motors=motors, **kwargs)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
|
||||
@@ -1,599 +0,0 @@
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def state_dims() -> list[str]:
|
||||
return ["observation.image", "observation.state"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def replay_buffer() -> ReplayBuffer:
|
||||
return create_empty_replay_buffer()
|
||||
|
||||
|
||||
def clone_state(state: dict) -> dict:
|
||||
return {k: v.clone() for k, v in state.items()}
|
||||
|
||||
|
||||
def create_empty_replay_buffer(
|
||||
optimize_memory: bool = False,
|
||||
use_drq: bool = False,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
) -> ReplayBuffer:
|
||||
buffer_capacity = 10
|
||||
device = "cpu"
|
||||
return ReplayBuffer(
|
||||
buffer_capacity,
|
||||
device,
|
||||
state_dims(),
|
||||
optimize_memory=optimize_memory,
|
||||
use_drq=use_drq,
|
||||
image_augmentation_function=image_augmentation_function,
|
||||
)
|
||||
|
||||
|
||||
def create_random_image() -> torch.Tensor:
|
||||
return torch.rand(3, 84, 84)
|
||||
|
||||
|
||||
def create_dummy_transition() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
"action": torch.randn(4),
|
||||
"reward": torch.tensor(1.0),
|
||||
"observation.state": torch.randn(
|
||||
10,
|
||||
),
|
||||
"done": torch.tensor(False),
|
||||
"truncated": torch.tensor(False),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
|
||||
def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]:
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
root = tmp_path / "test"
|
||||
return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer)
|
||||
|
||||
|
||||
def create_dummy_state() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
"observation.state": torch.randn(
|
||||
10,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_tensor_memory_consumption(tensor):
|
||||
return tensor.nelement() * tensor.element_size()
|
||||
|
||||
|
||||
def get_tensors_memory_consumption(obj, visited_addresses):
|
||||
total_size = 0
|
||||
|
||||
address = id(obj)
|
||||
if address in visited_addresses:
|
||||
return 0
|
||||
|
||||
visited_addresses.add(address)
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return get_tensor_memory_consumption(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
total_size += get_tensors_memory_consumption(item, visited_addresses)
|
||||
elif isinstance(obj, dict):
|
||||
for value in obj.values():
|
||||
total_size += get_tensors_memory_consumption(value, visited_addresses)
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# It's an object, we need to get the size of the attributes
|
||||
for _, attr in vars(obj).items():
|
||||
total_size += get_tensors_memory_consumption(attr, visited_addresses)
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def get_object_memory(obj):
|
||||
# Track visited addresses to avoid infinite loops
|
||||
# and cases when two properties point to the same object
|
||||
visited_addresses = set()
|
||||
|
||||
# Get the size of the object in bytes
|
||||
total_size = sys.getsizeof(obj)
|
||||
|
||||
# Get the size of the tensor attributes
|
||||
total_size += get_tensors_memory_consumption(obj, visited_addresses)
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
def create_dummy_action() -> torch.Tensor:
|
||||
return torch.randn(4)
|
||||
|
||||
|
||||
def dict_properties() -> list:
|
||||
return ["state", "next_state"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_state() -> dict:
|
||||
return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def next_dummy_state() -> dict:
|
||||
return create_dummy_state()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_action() -> torch.Tensor:
|
||||
return torch.randn(4)
|
||||
|
||||
|
||||
def test_empty_buffer_sample_raises_error(replay_buffer):
|
||||
assert len(replay_buffer) == 0, "Replay buffer should be empty."
|
||||
assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10."
|
||||
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
|
||||
replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_zero_capacity_buffer_raises_error():
|
||||
with pytest.raises(ValueError, match="Capacity must be greater than 0."):
|
||||
ReplayBuffer(0, "cpu", ["observation", "next_observation"])
|
||||
|
||||
|
||||
def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding."
|
||||
assert torch.equal(replay_buffer.actions[0], dummy_action), (
|
||||
"Action should be equal to the first transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
|
||||
assert not replay_buffer.dones[0], "Done should be False for the first transition."
|
||||
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
|
||||
"Observation should be equal to the first transition."
|
||||
)
|
||||
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), (
|
||||
"Next observation should be equal to the first transition."
|
||||
)
|
||||
|
||||
|
||||
def test_add_over_capacity():
|
||||
replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
|
||||
assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), (
|
||||
"Observation should be equal to the first transition."
|
||||
)
|
||||
assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), (
|
||||
"Next observation should be equal to the first transition."
|
||||
)
|
||||
|
||||
assert torch.equal(replay_buffer.actions[0], dummy_action_3), (
|
||||
"Action should be equal to the last transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
|
||||
assert replay_buffer.dones[0], "Done should be True for the first transition."
|
||||
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
|
||||
|
||||
|
||||
def test_sample_from_empty_buffer(replay_buffer):
|
||||
with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
|
||||
replay_buffer.sample(1)
|
||||
|
||||
|
||||
def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
|
||||
got_batch_transition = replay_buffer.sample(1)
|
||||
|
||||
expected_batch_transition = BatchTransition(
|
||||
state=clone_state(dummy_state),
|
||||
action=dummy_action.clone(),
|
||||
reward=1.0,
|
||||
next_state=clone_state(next_dummy_state),
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k, v in expected_batch_transition[buffer_property].items():
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
|
||||
assert got_state.device.type == "cpu", f"{k} should be on cpu."
|
||||
|
||||
assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition."
|
||||
|
||||
for key, _value in expected_batch_transition.items():
|
||||
if key in dict_properties():
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[key]
|
||||
|
||||
v_tensor = expected_batch_transition[key]
|
||||
if not isinstance(v_tensor, torch.Tensor):
|
||||
v_tensor = torch.tensor(v_tensor)
|
||||
|
||||
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
|
||||
assert got_value.device.type == "cpu", f"{key} should be on cpu."
|
||||
assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition."
|
||||
|
||||
|
||||
def test_sample_with_batch_bigger_than_buffer_size(
|
||||
replay_buffer, dummy_state, next_dummy_state, dummy_action
|
||||
):
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
|
||||
got_batch_transition = replay_buffer.sample(10)
|
||||
|
||||
expected_batch_transition = BatchTransition(
|
||||
state=dummy_state,
|
||||
action=dummy_action,
|
||||
reward=1.0,
|
||||
next_state=next_dummy_state,
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k in expected_batch_transition[buffer_property]:
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 1, f"{k} should have 1 transition."
|
||||
|
||||
for key in expected_batch_transition:
|
||||
if key in dict_properties():
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[key]
|
||||
assert got_value.shape[0] == 1, f"{key} should have 1 transition."
|
||||
|
||||
|
||||
def test_sample_batch(replay_buffer):
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True)
|
||||
|
||||
dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4]
|
||||
dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4]
|
||||
|
||||
got_batch_transition = replay_buffer.sample(3)
|
||||
|
||||
for buffer_property in dict_properties():
|
||||
for k in got_batch_transition[buffer_property]:
|
||||
got_state = got_batch_transition[buffer_property][k]
|
||||
|
||||
assert got_state.shape[0] == 3, f"{k} should have 3 transition."
|
||||
|
||||
for got_state_item in got_state:
|
||||
assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), (
|
||||
f"{k} should be equal to one of the dummy states."
|
||||
)
|
||||
|
||||
for got_action_item in got_batch_transition["action"]:
|
||||
assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
|
||||
"Actions should be equal to the dummy actions."
|
||||
)
|
||||
|
||||
for k in got_batch_transition:
|
||||
if k in dict_properties() or k == "complementary_info":
|
||||
continue
|
||||
|
||||
got_value = got_batch_transition[k]
|
||||
assert got_value.shape[0] == 3, f"{k} should have 3 transition."
|
||||
|
||||
|
||||
def test_to_lerobot_dataset_with_empty_buffer(replay_buffer):
|
||||
with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."):
|
||||
replay_buffer.to_lerobot_dataset("dummy_repo")
|
||||
|
||||
|
||||
def test_to_lerobot_dataset(tmp_path):
|
||||
ds, buffer = create_dataset_from_replay_buffer(tmp_path)
|
||||
|
||||
assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer"
|
||||
assert ds.fps == 1, "FPS should be 1"
|
||||
assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id"
|
||||
|
||||
for dim in state_dims():
|
||||
assert dim in ds.features
|
||||
assert ds.features[dim]["shape"] == buffer.states[dim][0].shape
|
||||
|
||||
assert ds.num_episodes == 2
|
||||
assert ds.num_frames == 4
|
||||
|
||||
for j, value in enumerate(ds):
|
||||
print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j]))
|
||||
|
||||
for i in range(len(ds)):
|
||||
for feature, value in ds[i].items():
|
||||
if feature == "action":
|
||||
assert torch.equal(value, buffer.actions[i])
|
||||
elif feature == "next.reward":
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == "next.done":
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
elif feature == "observation.image":
|
||||
# Tenssor -> numpy is not precise, so we have some diff there
|
||||
# TODO: Check and fix it
|
||||
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
|
||||
elif feature == "observation.state":
|
||||
assert torch.equal(value, buffer.states["observation.state"][i])
|
||||
|
||||
|
||||
def test_from_lerobot_dataset(tmp_path):
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
root = tmp_path / "test"
|
||||
ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root)
|
||||
|
||||
reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
|
||||
)
|
||||
|
||||
# Check only the part of the buffer that's actually filled with data
|
||||
assert torch.equal(
|
||||
reconverted_buffer.actions[: len(replay_buffer)],
|
||||
replay_buffer.actions[: len(replay_buffer)],
|
||||
), "Actions from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
||||
), "Dones from converted buffer should be equal to the original replay buffer."
|
||||
|
||||
# Lerobot DS haven't supported truncateds yet
|
||||
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
||||
"Truncateds from converted buffer should be equal False"
|
||||
)
|
||||
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
), "State should be the same after converting to dataset and return back"
|
||||
|
||||
for i in range(4):
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][i],
|
||||
reconverted_buffer.states["observation.image"][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
|
||||
# The 2, 3 frames have done flag, so their values will be equal to the current state
|
||||
for i in range(2):
|
||||
# In the current implementation we take the next state from the `states` and ignore `next_states`
|
||||
next_index = (i + 1) % 4
|
||||
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][next_index],
|
||||
reconverted_buffer.next_states["observation.image"][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
|
||||
for i in range(2, 4):
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][i],
|
||||
reconverted_buffer.next_states["observation.state"][i],
|
||||
)
|
||||
|
||||
|
||||
def test_buffer_sample_alignment():
|
||||
# Initialize buffer
|
||||
buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu")
|
||||
|
||||
# Fill buffer with patterned data
|
||||
for i in range(100):
|
||||
signature = float(i) / 100.0
|
||||
state = {"state_value": torch.tensor([[signature]]).float()}
|
||||
action = torch.tensor([[2.0 * signature]]).float()
|
||||
reward = 3.0 * signature
|
||||
|
||||
is_end = (i + 1) % 10 == 0
|
||||
if is_end:
|
||||
next_state = {"state_value": torch.tensor([[signature]]).float()}
|
||||
done = True
|
||||
else:
|
||||
next_signature = float(i + 1) / 100.0
|
||||
next_state = {"state_value": torch.tensor([[next_signature]]).float()}
|
||||
done = False
|
||||
|
||||
buffer.add(state, action, reward, next_state, done, False)
|
||||
|
||||
# Sample and verify
|
||||
batch = buffer.sample(50)
|
||||
|
||||
for i in range(50):
|
||||
state_sig = batch["state"]["state_value"][i].item()
|
||||
action_val = batch["action"][i].item()
|
||||
reward_val = batch["reward"][i].item()
|
||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||
is_done = batch["done"][i].item() > 0.5
|
||||
|
||||
# Verify relationships
|
||||
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
|
||||
f"Action {action_val} should be 2x state signature {state_sig}"
|
||||
)
|
||||
|
||||
assert abs(reward_val - 3.0 * state_sig) < 1e-4, (
|
||||
f"Reward {reward_val} should be 3x state signature {state_sig}"
|
||||
)
|
||||
|
||||
if is_done:
|
||||
assert abs(next_state_sig - state_sig) < 1e-4, (
|
||||
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
|
||||
)
|
||||
else:
|
||||
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
|
||||
valid_next = (
|
||||
abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
|
||||
)
|
||||
assert valid_next, (
|
||||
f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}"
|
||||
)
|
||||
|
||||
|
||||
def test_memory_optimization():
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
dummy_state_2 = create_dummy_state()
|
||||
dummy_action_2 = create_dummy_action()
|
||||
|
||||
dummy_state_3 = create_dummy_state()
|
||||
dummy_action_3 = create_dummy_action()
|
||||
|
||||
dummy_state_4 = create_dummy_state()
|
||||
dummy_action_4 = create_dummy_action()
|
||||
|
||||
replay_buffer = create_empty_replay_buffer()
|
||||
replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
|
||||
replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
|
||||
replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
|
||||
replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)
|
||||
|
||||
optimized_replay_buffer = create_empty_replay_buffer(True)
|
||||
optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
|
||||
optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True)
|
||||
|
||||
assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), (
|
||||
"Optimized replay buffer should be smaller than the original replay buffer"
|
||||
)
|
||||
|
||||
|
||||
def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action):
|
||||
def dummy_image_augmentation_function(x):
|
||||
return torch.ones_like(x) * 10
|
||||
|
||||
replay_buffer = create_empty_replay_buffer(
|
||||
use_drq=True, image_augmentation_function=dummy_image_augmentation_function
|
||||
)
|
||||
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
|
||||
|
||||
def test_check_image_augmentations_with_drq_and_default_image_augmentation_function(
|
||||
dummy_state, dummy_action
|
||||
):
|
||||
replay_buffer = create_empty_replay_buffer(use_drq=True)
|
||||
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
|
||||
# Let's check that it doesn't fail and shapes are correct
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
|
||||
|
||||
def test_random_crop_vectorized_basic():
|
||||
# Create a batch of 2 images with known patterns
|
||||
batch_size, channels, height, width = 2, 3, 10, 8
|
||||
images = torch.zeros((batch_size, channels, height, width))
|
||||
|
||||
# Fill with unique values for testing
|
||||
for b in range(batch_size):
|
||||
images[b] = b + 1
|
||||
|
||||
crop_size = (6, 4) # Smaller than original
|
||||
cropped = random_crop_vectorized(images, crop_size)
|
||||
|
||||
# Check output shape
|
||||
assert cropped.shape == (batch_size, channels, *crop_size)
|
||||
|
||||
# Check that values are preserved (should be either 1s or 2s for respective batches)
|
||||
assert torch.all(cropped[0] == 1)
|
||||
assert torch.all(cropped[1] == 2)
|
||||
|
||||
|
||||
def test_random_crop_vectorized_invalid_size():
|
||||
images = torch.zeros((2, 3, 10, 8))
|
||||
|
||||
# Test crop size larger than image
|
||||
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
|
||||
random_crop_vectorized(images, (12, 8))
|
||||
|
||||
with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
|
||||
random_crop_vectorized(images, (10, 10))
|
||||
Reference in New Issue
Block a user