Compare commits

...

29 Commits

Author SHA1 Message Date
Michel Aractingi
06fc9b89e1 pass entire config to make_optimizer 2024-09-02 08:20:17 +00:00
Michel Aractingi
3034272229 modified tests dirs 2024-09-02 08:04:56 +00:00
Michel Aractingi
bbce0eaeaf moved make optimizer and scheduler function to inside policy 2024-09-02 07:53:10 +00:00
Kenneth Gerald Hamilton
c0da806232 repair mailto link (#397) 2024-09-01 00:11:39 +02:00
Mishig
114e09f570 rm EpisodeSampler from viz (#389) 2024-08-30 10:53:55 +02:00
Simon Alibert
04a995e7d1 Fix safe_action (#395) 2024-08-30 10:36:05 +02:00
Michel Aractingi
4806336816 Add the possibility to visualize language instructions in visualize_dataset_html.py (#388)
Co-authored-by: Mishig <dmishig@gmail.com>
2024-08-28 11:50:31 +02:00
Remi
1ce418e4a1 Add koch bimanual (#385) 2024-08-28 00:53:31 +02:00
Michel Aractingi
eb4c505cff Support for converting OpenX datasets from RLDS format to LeRobotDataset (#354)
Signed-off-by: youliangtan <tan_you_liang@hotmail.com>
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: youliangtan <tan_you_liang@hotmail.com>
Co-authored-by: Remi <re.cadene@gmail.com>
2024-08-27 09:07:00 +02:00
Mishig
aad59e6b6b Fix videos in visualize_dataset are not in sync (#382) 2024-08-26 17:38:48 +02:00
Alexander Soare
9ce98bb93c Add safety limits on relative action target (#373) 2024-08-26 14:30:18 +01:00
Alexander Soare
97086cdcdf Make gripper_open_degree a config param (#379) 2024-08-26 12:28:16 +01:00
Alexander Soare
9c7649f140 Make sure init_hydra_config does not require any keys (#376) 2024-08-23 12:27:08 +01:00
Zhuoheng Li
a2592a5563 Provide more information to the user (#358)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Remi <re.cadene@gmail.com>
2024-08-23 11:00:35 +01:00
ellacroix
b5ad79a7d3 Fix typo in tutorial (#371) 2024-08-21 14:14:01 +02:00
Remi
996468bcce Update README.md 2024-08-20 16:45:57 +02:00
Remi
f98200297d Slightly improve tutorial and README (#370)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2024-08-20 16:41:39 +02:00
NielsRogge
86bbd16d43 Improve discoverability on the hub (#325)
Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2024-08-19 15:16:46 +02:00
Alexander Soare
0f6e0f6d74 Fix input dim (#365) 2024-08-19 11:42:32 +01:00
Remi
fc3e545e03 Update README.md 2024-08-19 11:14:10 +02:00
Simon Alibert
b98ea415c1 Add dataset cards (#363) 2024-08-16 10:08:44 +02:00
Remi
bbe9057225 Improve control robot ; Add process to configure motor indices (#326)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: jess-moss <jess.moss@dextrousrobotics.com>
Co-authored-by: Marina Barannikov <marina.barannikov@huggingface.co>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-08-15 18:11:33 +02:00
Alexander Soare
8c4643687c fix bug in example 2 (#361) 2024-08-15 13:59:47 +01:00
Julien Perez
fab037f78d feat for the GPU poors : Add GPU availability check in evaluate_pretr… (#359)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-08-13 16:03:05 +01:00
Simon Alibert
03d647269e Fix CI builds (#357) 2024-08-12 17:57:03 +02:00
Remi
2252b42337 Add visualize_dataset_html with http.server (#188) 2024-08-08 20:19:06 +03:00
Adrien
bc6384bb80 fix ci (#351)
Signed-off-by: Adrien <adrien@huggingface.co>
2024-08-05 16:12:26 +02:00
resolver101757
8df7e63d61 Update README for cross-platform installation compatibility (#347) 2024-07-30 00:48:41 +02:00
Halvard Bariller
7a3cb1ad34 Adjust the timestamps' description in Diffusion Policy (#343)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-07-26 12:47:03 +01:00
68 changed files with 5641 additions and 699 deletions

View File

@@ -14,20 +14,14 @@ env:
jobs:
latest-cpu:
name: CPU
runs-on: ubuntu-latest
runs-on:
group: aws-general-8-plus
steps:
- name: Cleanup disk
- name: Install Git LFS
run: |
sudo df -h
# sudo ls -l /usr/local/lib/
# sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo df -h
sudo apt-get update
sudo apt-get install git-lfs
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -55,20 +49,15 @@ jobs:
latest-cuda:
name: GPU
runs-on: ubuntu-latest
runs-on:
group: aws-general-8-plus
steps:
- name: Cleanup disk
- name: Install Git LFS
run: |
sudo df -h
# sudo ls -l /usr/local/lib/
# sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo df -h
sudo apt-get update
sudo apt-get install git-lfs
git lfs install
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -95,20 +84,9 @@ jobs:
latest-cuda-dev:
name: GPU Dev
runs-on: ubuntu-latest
runs-on:
group: aws-general-8-plus
steps:
- name: Cleanup disk
run: |
sudo df -h
# sudo ls -l /usr/local/lib/
# sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -16,7 +16,8 @@ jobs:
name: CPU
strategy:
fail-fast: false
runs-on: ubuntu-latest
runs-on:
group: aws-general-8-plus
container:
image: huggingface/lerobot-cpu:latest
options: --shm-size "16gb"
@@ -43,7 +44,8 @@ jobs:
name: GPU
strategy:
fail-fast: false
runs-on: [single-gpu, nvidia-gpu, t4, ci]
runs-on:
group: aws-g6-4xlarge-plus
env:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu"

View File

@@ -42,26 +42,14 @@ jobs:
build_modified_dockerfiles:
name: Build modified Docker images
needs: get_changed_files
runs-on: ubuntu-latest
runs-on:
group: aws-general-8-plus
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
strategy:
fail-fast: false
matrix:
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
steps:
- name: Cleanup disk
run: |
sudo df -h
# sudo ls -l /usr/local/lib/
# sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

1
.gitignore vendored
View File

@@ -121,6 +121,7 @@ celerybeat.pid
# Environments
.env
.venv
env/
venv/
env.bak/
venv.bak/

View File

@@ -20,7 +20,7 @@ Some of the ways you can contribute to 🤗 LeRobot:
* Contributing to the examples or to the documentation.
* Submitting issues related to bugs or desired new features.
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co).
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)

View File

@@ -22,8 +22,22 @@
</div>
<h2 align="center">
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">Hot new tutorial: Getting started with real-world robots</a></p>
</h2>
<div align="center">
<img src="media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
<p>We just dropped an in-depth tutorial on how to build your own robot!</p>
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
<p>Then watch your homemade robot act autonomously 🤯</p>
<p>For more info, see <a href="https://x.com/RemiCadene/status/1825455895561859185">our thread on X</a> or <a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">our tutorial page</a>.</p>
</div>
<br/>
<h3 align="center">
<p>State-of-the-art Machine Learning for real-world robotics</p>
<p>LeRobot: State-of-the-art AI for real-world robotics</p>
</h3>
---
@@ -65,17 +79,19 @@
Download our source code:
```bash
git clone https://github.com/huggingface/lerobot.git && cd lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
```
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
conda create -y -n lerobot python=3.10
conda activate lerobot
```
Install 🤗 LeRobot:
```bash
pip install .
pip install -e .
```
> **NOTE:** Depending on your platform, If you encounter any build errors during this step
@@ -89,7 +105,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
For instance, to install 🤗 LeRobot with aloha and pusht, use:
```bash
pip install ".[aloha, pusht]"
pip install -e ".[aloha, pusht]"
```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
@@ -114,10 +130,12 @@ wandb login
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
| | ├── envs # various sim environments: aloha, pusht, xarm
| | ├── policies # various policies: act, diffusion, tdmpc
| | ├── robot_devices # various real devices: dynamixel motors, opencv cameras, koch robots
| | └── utils # various utilities
| └── scripts # contains functions to execute via command line
| ├── eval.py # load policy and evaluate it on an environment
| ├── train.py # train a policy via imitation learning and/or reinforcement learning
| ├── control_robot.py # teleoperate a real robot, record data, run a policy
| ├── push_dataset_to_hub.py # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
| └── visualize_dataset.py # load a dataset and render its demonstrations
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
@@ -249,13 +267,20 @@ checkpoints
│ └── training_state.pth # optimizer/scheduler/rng state and training step
```
To resume training from a checkpoint, you can add these to the `train.py` python command:
```bash
hydra.run.dir=your/original/experiment/dir resume=true
```
It will load the pretrained model, optimizer and scheduler states for training. For more information please see our tutorial on training resumption [here](https://github.com/huggingface/lerobot/blob/main/examples/5_resume_training.md).
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
```bash
wandb.enable=true
```
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explaination of some commonly used metrics in logs.
![](media/wandb.png)

View File

@@ -9,6 +9,7 @@ ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Create virtual environment

View File

@@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
sed gawk grep curl wget zip unzip \
tcpdump sysstat screen tmux \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
speech-dispatcher \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*

View File

@@ -9,6 +9,7 @@ ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher \
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*

View File

@@ -18,8 +18,6 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
device = torch.device("cuda")
# Download the diffusion policy for pusht environment
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
@@ -27,6 +25,17 @@ pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
# Check if GPU is available
if torch.cuda.is_available():
device = torch.device("cuda")
print("GPU is available. Device set to:", device)
else:
device = torch.device("cpu")
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
policy.diffusion.num_inference_steps = 10
policy.to(device)
# Initialize evaluation environment to render two observation types:

View File

@@ -170,6 +170,36 @@ python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpo
Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).
## Typical logs and metrics
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
After that, you will see training log like this one:
```
INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
```
or evaluation log like:
```
INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
```
These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meaning of some abbreviations:
- `smpl`: number of samples seen during training.
- `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task.
- `epch`: number of time all unique samples are seen (epoch).
- `grdn`: gradient norm.
- `∑rwrd`: compute the sum of rewards in every evaluation episode and then take an average of them.
- `success`: average success rate of eval episodes. Reward and success are usually different except for the sparsing reward setting, where reward=1 only when the task is completed successfully.
- `eval_s`: time to evaluate the policy in the environment, in second.
- `updt_s`: time to update the network parameters, in second.
- `data_s`: time to load a batch of data, in second.
Some metrics are useful for initial performance profiling. For example, if you find the current GPU utilization is low via the `nvidia-smi` command and `data_s` sometimes is too high, you may need to modify batch size or number of dataloading workers to accelerate dataloading. We also recommend [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.
---
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):

File diff suppressed because it is too large Load Diff

View File

@@ -129,6 +129,53 @@ available_real_world_datasets = [
"lerobot/unitreeh1_rearrange_objects",
"lerobot/unitreeh1_two_robot_greeting",
"lerobot/unitreeh1_warehouse",
"lerobot/nyu_rot_dataset",
"lerobot/utokyo_saytap",
"lerobot/imperialcollege_sawyer_wrist_cam",
"lerobot/utokyo_xarm_bimanual",
"lerobot/tokyo_u_lsmo",
"lerobot/utokyo_pr2_opening_fridge",
"lerobot/cmu_franka_exploration_dataset",
"lerobot/cmu_stretch",
"lerobot/asu_table_top",
"lerobot/utokyo_pr2_tabletop_manipulation",
"lerobot/utokyo_xarm_pick_and_place",
"lerobot/ucsd_kitchen_dataset",
"lerobot/austin_buds_dataset",
"lerobot/dlr_sara_grid_clamp",
"lerobot/conq_hose_manipulation",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/dlr_sara_pour",
"lerobot/dlr_edan_shared_control",
"lerobot/ucsd_pick_and_place_dataset",
"lerobot/berkeley_cable_routing",
"lerobot/nyu_franka_play_dataset",
"lerobot/austin_sirius_dataset",
"lerobot/cmu_play_fusion",
"lerobot/berkeley_gnm_sac_son",
"lerobot/nyu_door_opening_surprising_effectiveness",
"lerobot/berkeley_fanuc_manipulation",
"lerobot/jaco_play",
"lerobot/viola",
"lerobot/kaist_nonprehensile",
"lerobot/berkeley_mvp",
"lerobot/uiuc_d3field",
"lerobot/berkeley_gnm_recon",
"lerobot/austin_sailor_dataset",
"lerobot/utaustin_mutex",
"lerobot/roboturk",
"lerobot/stanford_hydra_dataset",
"lerobot/berkeley_autolab_ur5",
"lerobot/stanford_robocook",
"lerobot/toto",
"lerobot/fmb",
"lerobot/droid_100",
"lerobot/berkeley_rpt",
"lerobot/stanford_kuka_multimodal_dataset",
"lerobot/iamlab_cmu_pickup_insert",
"lerobot/taco_play",
"lerobot/berkeley_gnm_cory_hall",
"lerobot/usc_cloth_sim",
]
available_datasets = list(

View File

@@ -40,6 +40,10 @@ def get_stats_einops_patterns(dataset, num_workers=0):
stats_patterns = {}
for key, feats_type in dataset.features.items():
# NOTE: skip language_instruction embedding in stats computation
if key == "language_instruction":
continue
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64

View File

@@ -60,8 +60,8 @@ AVAILABLE_RAW_REPO_IDS = {
"lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
"lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
"lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
"lerobot-raw/pusht_raw": "pusht_zarr",
"lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
"lerobot-raw/pusht_raw": "pusht_zarr",
"lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
"lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
"lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
@@ -70,6 +70,74 @@ AVAILABLE_RAW_REPO_IDS = {
"lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
"lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
"lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
"lerobot-raw/fractal20220817_data_raw": "openx_rlds.fractal20220817_data",
"lerobot-raw/kuka_raw": "openx_rlds.kuka",
"lerobot-raw/bridge_openx_raw": "openx_rlds.bridge_openx",
"lerobot-raw/taco_play_raw": "openx_rlds.taco_play",
"lerobot-raw/jaco_play_raw": "openx_rlds.jaco_play",
"lerobot-raw/berkeley_cable_routing_raw": "openx_rlds.berkeley_cable_routing",
"lerobot-raw/roboturk_raw": "openx_rlds.roboturk",
"lerobot-raw/nyu_door_opening_surprising_effectiveness_raw": "openx_rlds.nyu_door_opening_surprising_effectiveness",
"lerobot-raw/viola_raw": "openx_rlds.viola",
"lerobot-raw/berkeley_autolab_ur5_raw": "openx_rlds.berkeley_autolab_ur5",
"lerobot-raw/toto_raw": "openx_rlds.toto",
"lerobot-raw/language_table_raw": "openx_rlds.language_table",
"lerobot-raw/columbia_cairlab_pusht_real_raw": "openx_rlds.columbia_cairlab_pusht_real",
"lerobot-raw/stanford_kuka_multimodal_dataset_raw": "openx_rlds.stanford_kuka_multimodal_dataset",
"lerobot-raw/nyu_rot_dataset_raw": "openx_rlds.nyu_rot_dataset",
"lerobot-raw/io_ai_tech_raw": "openx_rlds.io_ai_tech",
"lerobot-raw/stanford_hydra_dataset_raw": "openx_rlds.stanford_hydra_dataset",
"lerobot-raw/austin_buds_dataset_raw": "openx_rlds.austin_buds_dataset",
"lerobot-raw/nyu_franka_play_dataset_raw": "openx_rlds.nyu_franka_play_dataset",
"lerobot-raw/maniskill_dataset_raw": "openx_rlds.maniskill_dataset",
"lerobot-raw/furniture_bench_dataset_raw": "openx_rlds.furniture_bench_dataset",
"lerobot-raw/cmu_franka_exploration_dataset_raw": "openx_rlds.cmu_franka_exploration_dataset",
"lerobot-raw/ucsd_kitchen_dataset_raw": "openx_rlds.ucsd_kitchen_dataset",
"lerobot-raw/ucsd_pick_and_place_dataset_raw": "openx_rlds.ucsd_pick_and_place_dataset",
"lerobot-raw/spoc_raw": "openx_rlds.spoc",
"lerobot-raw/austin_sailor_dataset_raw": "openx_rlds.austin_sailor_dataset",
"lerobot-raw/austin_sirius_dataset_raw": "openx_rlds.austin_sirius_dataset",
"lerobot-raw/bc_z_raw": "openx_rlds.bc_z",
"lerobot-raw/utokyo_pr2_opening_fridge_raw": "openx_rlds.utokyo_pr2_opening_fridge",
"lerobot-raw/utokyo_pr2_tabletop_manipulation_raw": "openx_rlds.utokyo_pr2_tabletop_manipulation",
"lerobot-raw/utokyo_xarm_pick_and_place_raw": "openx_rlds.utokyo_xarm_pick_and_place",
"lerobot-raw/utokyo_xarm_bimanual_raw": "openx_rlds.utokyo_xarm_bimanual",
"lerobot-raw/utokyo_saytap_raw": "openx_rlds.utokyo_saytap",
"lerobot-raw/robo_net_raw": "openx_rlds.robo_net",
"lerobot-raw/robo_set_raw": "openx_rlds.robo_set",
"lerobot-raw/berkeley_mvp_raw": "openx_rlds.berkeley_mvp",
"lerobot-raw/berkeley_rpt_raw": "openx_rlds.berkeley_rpt",
"lerobot-raw/kaist_nonprehensile_raw": "openx_rlds.kaist_nonprehensile",
"lerobot-raw/stanford_mask_vit_raw": "openx_rlds.stanford_mask_vit",
"lerobot-raw/tokyo_u_lsmo_raw": "openx_rlds.tokyo_u_lsmo",
"lerobot-raw/dlr_sara_pour_raw": "openx_rlds.dlr_sara_pour",
"lerobot-raw/dlr_sara_grid_clamp_raw": "openx_rlds.dlr_sara_grid_clamp",
"lerobot-raw/dlr_edan_shared_control_raw": "openx_rlds.dlr_edan_shared_control",
"lerobot-raw/asu_table_top_raw": "openx_rlds.asu_table_top",
"lerobot-raw/stanford_robocook_raw": "openx_rlds.stanford_robocook",
"lerobot-raw/imperialcollege_sawyer_wrist_cam_raw": "openx_rlds.imperialcollege_sawyer_wrist_cam",
"lerobot-raw/iamlab_cmu_pickup_insert_raw": "openx_rlds.iamlab_cmu_pickup_insert",
"lerobot-raw/uiuc_d3field_raw": "openx_rlds.uiuc_d3field",
"lerobot-raw/utaustin_mutex_raw": "openx_rlds.utaustin_mutex",
"lerobot-raw/berkeley_fanuc_manipulation_raw": "openx_rlds.berkeley_fanuc_manipulation",
"lerobot-raw/cmu_playing_with_food_raw": "openx_rlds.cmu_playing_with_food",
"lerobot-raw/cmu_play_fusion_raw": "openx_rlds.cmu_play_fusion",
"lerobot-raw/cmu_stretch_raw": "openx_rlds.cmu_stretch",
"lerobot-raw/berkeley_gnm_recon_raw": "openx_rlds.berkeley_gnm_recon",
"lerobot-raw/berkeley_gnm_cory_hall_raw": "openx_rlds.berkeley_gnm_cory_hall",
"lerobot-raw/berkeley_gnm_sac_son_raw": "openx_rlds.berkeley_gnm_sac_son",
"lerobot-raw/droid_raw": "openx_rlds.droid",
"lerobot-raw/droid_100_raw": "openx_rlds.droid100",
"lerobot-raw/fmb_raw": "openx_rlds.fmb",
"lerobot-raw/dobbe_raw": "openx_rlds.dobbe",
"lerobot-raw/usc_cloth_sim_raw": "openx_rlds.usc_cloth_sim",
"lerobot-raw/plex_robosuite_raw": "openx_rlds.plex_robosuite",
"lerobot-raw/conq_hose_manipulation_raw": "openx_rlds.conq_hose_manipulation",
"lerobot-raw/vima_raw": "openx_rlds.vima",
"lerobot-raw/robot_vqa_raw": "openx_rlds.robot_vqa",
"lerobot-raw/mimic_play_raw": "openx_rlds.mimic_play",
"lerobot-raw/tidybot_raw": "openx_rlds.tidybot",
"lerobot-raw/eth_agent_affordances_raw": "openx_rlds.eth_agent_affordances",
}
@@ -110,7 +178,7 @@ def download_all_raw_datasets(data_dir: Path | None = None):
def main():
parser = argparse.ArgumentParser(
description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}""",
non exhaustive list of available repositories to use in `--repo-id`: {list(AVAILABLE_RAW_REPO_IDS.keys())}""",
)
parser.add_argument(

View File

@@ -0,0 +1,640 @@
OPENX_DATASET_CONFIGS:
fractal20220817_data:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- base_pose_tool_reached
- gripper_closed
fps: 3
kuka:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- clip_function_input/base_pose_tool_reached
- gripper_closed
fps: 10
bridge_openx:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- EEF_state
- gripper_state
fps: 5
taco_play:
image_obs_keys:
- rgb_static
- rgb_gripper
depth_obs_keys:
- depth_static
- depth_gripper
state_obs_keys:
- state_eef
- state_gripper
fps: 15
jaco_play:
image_obs_keys:
- image
- image_wrist
depth_obs_keys:
- null
state_obs_keys:
- state_eef
- state_gripper
fps: 10
berkeley_cable_routing:
image_obs_keys:
- image
- top_image
- wrist45_image
- wrist225_image
depth_obs_keys:
- null
state_obs_keys:
- robot_state
fps: 10
roboturk:
image_obs_keys:
- front_rgb
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
nyu_door_opening_surprising_effectiveness:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 3
viola:
image_obs_keys:
- agentview_rgb
- eye_in_hand_rgb
depth_obs_keys:
- null
state_obs_keys:
- joint_states
- gripper_states
fps: 20
berkeley_autolab_ur5:
image_obs_keys:
- image
- hand_image
depth_obs_keys:
- image_with_depth
state_obs_keys:
- state
fps: 5
toto:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 30
language_table:
image_obs_keys:
- rgb
depth_obs_keys:
- null
state_obs_keys:
- effector_translation
fps: 10
columbia_cairlab_pusht_real:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- robot_state
fps: 10
stanford_kuka_multimodal_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- depth_image
state_obs_keys:
- ee_position
- ee_orientation
fps: 20
nyu_rot_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 3
io_ai_tech:
image_obs_keys:
- image
- image_fisheye
- image_left_side
- image_right_side
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 3
stanford_hydra_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
austin_buds_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
nyu_franka_play_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- image_additional_view
depth_obs_keys:
- depth
- depth_additional_view
state_obs_keys:
- eef_state
fps: 3
maniskill_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- depth
- wrist_depth
state_obs_keys:
- tcp_pose
- gripper_state
fps: 20
furniture_bench_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
cmu_franka_exploration_dataset_converted_externally_to_rlds:
image_obs_keys:
- highres_image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
ucsd_kitchen_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
fps: 2
ucsd_pick_and_place_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 3
spoc:
image_obs_keys:
- image
- image_manipulation
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 3
austin_sailor_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
austin_sirius_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
bc_z:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- present/xyz
- present/axis_angle
- present/sensed_close
fps: 10
utokyo_pr2_opening_fridge_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
utokyo_xarm_pick_and_place_converted_externally_to_rlds:
image_obs_keys:
- image
- image2
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- end_effector_pose
fps: 10
utokyo_xarm_bimanual_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- pose_r
fps: 10
robo_net:
image_obs_keys:
- image
- image1
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 1
robo_set:
image_obs_keys:
- image_left
- image_right
- image_wrist
depth_obs_keys:
- null
state_obs_keys:
- state
- state_velocity
fps: 5
berkeley_mvp_converted_externally_to_rlds:
image_obs_keys:
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- gripper
- pose
- joint_pos
fps: 5
berkeley_rpt_converted_externally_to_rlds:
image_obs_keys:
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- joint_pos
- gripper
fps: 30
kaist_nonprehensile_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
stanford_mask_vit_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
tokyo_u_lsmo_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
dlr_sara_pour_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
dlr_sara_grid_clamp_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
dlr_edan_shared_control_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 5
asu_table_top_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 12.5
stanford_robocook_converted_externally_to_rlds:
image_obs_keys:
- image_1
- image_2
depth_obs_keys:
- depth_1
- depth_2
state_obs_keys:
- eef_state
- gripper_state
fps: 5
imperialcollege_sawyer_wrist_cam:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
iamlab_cmu_pickup_insert_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
- gripper_state
fps: 20
uiuc_d3field:
image_obs_keys:
- image_1
- image_2
depth_obs_keys:
- depth_1
- depth_2
state_obs_keys:
- null
fps: 1
utaustin_mutex:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
berkeley_fanuc_manipulation:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
- gripper_state
fps: 10
cmu_playing_with_food:
image_obs_keys:
- image
- finger_vision_1
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
cmu_play_fusion:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 5
cmu_stretch:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
berkeley_gnm_recon:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 3
berkeley_gnm_cory_hall:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 5
berkeley_gnm_sac_son:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 10
droid:
image_obs_keys:
- exterior_image_1_left
- exterior_image_2_left
- wrist_image_left
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 15
droid_100:
image_obs_keys:
- exterior_image_1_left
- exterior_image_2_left
- wrist_image_left
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 15
fmb:
image_obs_keys:
- image_side_1
- image_side_2
- image_wrist_1
- image_wrist_2
depth_obs_keys:
- image_side_1_depth
- image_side_2_depth
- image_wrist_1_depth
- image_wrist_2_depth
state_obs_keys:
- proprio
fps: 10
dobbe:
image_obs_keys:
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 3.75
usc_cloth_sim_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
plex_robosuite:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
conq_hose_manipulation:
image_obs_keys:
- frontleft_fisheye_image
- frontright_fisheye_image
- hand_color_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 30

View File

@@ -0,0 +1,106 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the Licens e.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py
data_utils.py
Additional utils for data processing.
"""
from typing import Any, Dict, List
import tensorflow as tf
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts gripper actions from continuous to binary values (0 and 1).
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
values based on the state that is reached _after_ those intermediate values.
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
chunk of intermediate values as the last action in the trajectory.
The `scan_fn` implements the following logic:
new_actions = np.empty_like(actions)
carry = actions[-1]
for i in reversed(range(actions.shape[0])):
if in_between_mask[i]:
carry = carry
else:
carry = float(open_mask[i])
new_actions[i] = carry
"""
open_mask, closed_mask = actions > 0.95, actions < 0.05
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
is_open_float = tf.cast(open_mask, tf.float32)
def scan_fn(carry, i):
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
return 1 - actions
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
"""
# Note =>> -1 for closing, 1 for opening, 0 for no change
opening_mask, closing_mask = actions < -0.1, actions > 0.1
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
def scan_fn(carry, i):
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
# If no relative grasp, assumes open for whole trajectory
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
start = tf.cond(start == 0, lambda: 1, lambda: start)
# Note =>> -1 for closed, 1 for open
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
return new_actions
# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
return traj_truncated
# === RLDS Dataset Initialization Utilities ===
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
print("\n######################################################################################")
print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
pad = 80 - len(dataset_kwargs["name"])
print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
print("######################################################################################\n")

View File

@@ -0,0 +1,200 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
OpenVLA: https://github.com/openvla/openvla
Episode transforms for DROID dataset.
"""
from typing import Any, Dict
import tensorflow as tf
import tensorflow_graphics.geometry.transformation as tfg
def rmat_to_euler(rot_mat):
return tfg.euler.from_rotation_matrix(rot_mat)
def euler_to_rmat(euler):
return tfg.rotation_matrix_3d.from_euler(euler)
def invert_rmat(rot_mat):
return tfg.rotation_matrix_3d.inverse(rot_mat)
def rotmat_to_rot6d(mat):
"""
Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
Args:
mat: rotation matrix
Returns: 6d vector (first two rows of rotation matrix)
"""
r6 = mat[..., :2, :]
r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
r6_flat = tf.concat([r6_0, r6_1], axis=-1)
return r6_flat
def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
"""
Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
Args:
velocity: 6d velocity action (3 x translation, 3 x rotation)
wrist_in_robot_frame: 6d pose of the end-effector in robot base frame
Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
"""
r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
r_frame_inv = invert_rmat(r_frame)
# world to wrist: dT_pi = R^-1 dT_rbt
vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0]
# world to wrist: dR_pi = R^-1 dR_rbt R
dr_ = euler_to_rmat(velocity[:, 3:6])
dr_ = r_frame_inv @ (dr_ @ r_frame)
dr_r6 = rotmat_to_rot6d(dr_)
return tf.concat([vel_t, dr_r6], axis=-1)
def rand_swap_exterior_images(img1, img2):
"""
Randomly swaps the two exterior images (for training with single exterior input).
"""
return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dr_,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *wrist* frame of the robot.
"""
wrist_act = velocity_act_to_wrist_frame(
trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
)
trajectory["action"] = tf.concat(
(
wrist_act,
trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dr_,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def zero_action_filter(traj: Dict) -> bool:
"""
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
"""
droid_q01 = tf.convert_to_tensor(
[
-0.7776297926902771,
-0.5803514122962952,
-0.5795090794563293,
-0.6464047729969025,
-0.7041108310222626,
-0.8895104378461838,
]
)
droid_q99 = tf.convert_to_tensor(
[
0.7597932070493698,
0.5726242214441299,
0.7351000607013702,
0.6705610305070877,
0.6464948207139969,
0.8897542208433151,
]
)
droid_norm_0_act = (
2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1
)
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5)

View File

@@ -0,0 +1,859 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
OpenVLA: https://github.com/openvla/openvla
Octo: https://github.com/octo-models/octo
transforms.py
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
Transforms adopt the following structure:
Input: Dictionary of *batched* features (i.e., has leading time dimension)
Output: Dictionary `step` =>> {
"observation": {
<image_keys, depth_image_keys>
State (in chosen state representation)
},
"action": Action (in chosen action representation),
"language_instruction": str
}
"""
from typing import Any, Dict
import tensorflow as tf
from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import (
binarize_gripper_actions,
invert_gripper_actions,
rel2abs_gripper_actions,
relabel_bridge_actions,
)
def droid_baseact_transform_fn():
from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform
return droid_baseact_transform
def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to version of Bridge V2 in Open X-Embodiment mixture.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key in ["observation", "action"]:
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to original version of Bridge V2 from the official project website.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key == "observation":
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
return trajectory
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
# decode compressed state
eef_value = tf.io.decode_compressed(
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
compression_type="ZLIB",
)
eef_value = tf.io.decode_raw(eef_value, tf.float32)
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
gripper_value = tf.io.decode_compressed(
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
)
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
trajectory["action"] = trajectory["action"]["rel_actions_world"]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
:, -1:
]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
tf.zeros_like(trajectory["action"]["world_vector"]),
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert absolute gripper action, +1 = open, 0 = close
gripper_action = invert_gripper_actions(
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory["language_embedding"] = trajectory["observation"]["natural_language_embedding"]
return trajectory
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
gripper_action = invert_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# default to "open" gripper
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.ones_like(trajectory["action"][:, :1]),
),
axis=-1,
)
# decode language instruction
instruction_bytes = trajectory["observation"]["instruction"]
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
:, 0
]
return trajectory
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
trajectory["action"]["gripper_closedness_action"][:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
trajectory["action"] = trajectory["action"][..., :7]
return trajectory
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
trajectory["observation"]["state"][:, 7:10],
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
return trajectory
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
return trajectory
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
trajectory["observation"]["depth_additional_view"] = tf.cast(
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
)
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
# clip gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, -8:-2],
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
),
axis=-1,
)
return trajectory
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
return trajectory
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["state"][:, :7],
trajectory["observation"]["state"][:, -1:],
),
axis=-1,
)
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["future/xyz_residual"][:, :3],
trajectory["action"]["future/axis_angle_residual"][:, :3],
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., -7:]
return trajectory
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :4],
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
trajectory["observation"]["state"] = tf.concat((
tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32),
trajectory["observation"]["pose"],
trajectory["observation"]["joint_pos"],),
axis=-1,)
"""
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
return trajectory
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
return trajectory
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["end_effector_pose"][:, :4],
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
return trajectory
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
return trajectory
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, 7:8],
),
axis=-1,
)
return trajectory
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"],
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
),
axis=-1,
)
return trajectory
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
trajectory["action"][:, -4:],
),
axis=-1,
)
return trajectory
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["position"],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
trajectory["observation"]["yaw"],
),
axis=-1,
)
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["eef_pose"],
trajectory["observation"]["state_gripper_pose"][..., None],
),
axis=-1,
)
return trajectory
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
return trajectory
def robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# gripper action is in -1...1 --> clip to 0...1, flip
gripper_action = trajectory["action"][:, -1:]
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :7],
gripper_action,
),
axis=-1,
)
return trajectory
def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
# === Registry ===
OPENX_STANDARDIZATION_TRANSFORMS = {
"bridge_openx": bridge_openx_dataset_transform,
"bridge_orig": bridge_orig_dataset_transform,
"bridge_dataset": bridge_orig_dataset_transform,
"ppgm": ppgm_dataset_transform,
"ppgm_static": ppgm_dataset_transform,
"ppgm_wrist": ppgm_dataset_transform,
"fractal20220817_data": rt1_dataset_transform,
"kuka": kuka_dataset_transform,
"taco_play": taco_play_dataset_transform,
"jaco_play": jaco_play_dataset_transform,
"berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
"roboturk": roboturk_dataset_transform,
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
"viola": viola_dataset_transform,
"berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
"toto": toto_dataset_transform,
"language_table": language_table_dataset_transform,
"columbia_cairlab_pusht_real": pusht_dataset_transform,
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
"nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
"stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
"austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
"nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
"maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
"furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
"cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
"ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
"austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
"austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
"bc_z": bc_z_dataset_transform,
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform,
"utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
"robo_net": robo_net_dataset_transform,
"berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
"berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
"kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
"stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
"tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
"dlr_sara_pour_converted_externally_to_rlds": identity_transform,
"dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
"dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
"asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
"stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
"imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
"uiuc_d3field": uiuc_d3field_dataset_transform,
"utaustin_mutex": utaustin_mutex_dataset_transform,
"berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
"cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
"cmu_play_fusion": playfusion_dataset_transform,
"cmu_stretch": cmu_stretch_dataset_transform,
"berkeley_gnm_recon": gnm_dataset_transform,
"berkeley_gnm_cory_hall": gnm_dataset_transform,
"berkeley_gnm_sac_son": gnm_dataset_transform,
"droid": droid_baseact_transform_fn(),
"droid_100": droid_baseact_transform_fn(), # first 100 episodes of droid
"fmb": fmb_transform,
"dobbe": dobbe_dataset_transform,
"robo_set": robo_set_dataset_transform,
"usc_cloth_sim_converted_externally_to_rlds": identity_transform,
"plex_robosuite": identity_transform,
"conq_hose_manipulation": identity_transform,
"io_ai_tech": identity_transform,
"spoc": identity_transform,
}

View File

@@ -0,0 +1,359 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
Example:
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
--repo-id youliangtan/sampled_bridge_data_v2 \
--raw-format openx_rlds.bridge_orig \
--episodes 3 4 5 8 9
Exact dataset fps defined in openx/config.py, obtained from:
https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0&range=R:R
"""
import shutil
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
import yaml
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import (
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml", "r") as f:
_openx_list = yaml.safe_load(f)
OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
np.set_printoptions(precision=2)
def tf_to_torch(data):
return torch.from_numpy(data.numpy())
def tf_img_convert(img):
if img.dtype == tf.string:
img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)
elif img.dtype != tf.uint8:
raise ValueError(f"Unsupported image dtype: found with dtype {img.dtype}")
return img.numpy()
def _broadcast_metadata_rlds(i: tf.Tensor, traj: dict) -> dict:
"""
In the RLDS format, each trajectory has some top-level metadata that is explicitly separated out, and a "steps"
entry. This function moves the "steps" entry to the top level, broadcasting any metadata to the length of the
trajectory. This function also adds the extra metadata fields `_len`, `_traj_index`, and `_frame_index`.
NOTE: adapted from DLimp library https://github.com/kvablack/dlimp/
"""
steps = traj.pop("steps")
traj_len = tf.shape(tf.nest.flatten(steps)[0])[0]
# broadcast metadata to the length of the trajectory
metadata = tf.nest.map_structure(lambda x: tf.repeat(x, traj_len), traj)
# put steps back in
assert "traj_metadata" not in steps
traj = {**steps, "traj_metadata": metadata}
assert "_len" not in traj
assert "_traj_index" not in traj
assert "_frame_index" not in traj
traj["_len"] = tf.repeat(traj_len, traj_len)
traj["_traj_index"] = tf.repeat(i, traj_len)
traj["_frame_index"] = tf.range(traj_len)
return traj
def load_from_raw(
raw_dir: Path,
videos_dir: Path,
fps: int,
video: bool,
episodes: list[int] | None = None,
encoding: dict | None = None,
openx_dataset_name: str | None = None,
):
"""
Args:
raw_dir (Path): _description_
videos_dir (Path): _description_
fps (int): _description_
video (bool): _description_
episodes (list[int] | None, optional): _description_. Defaults to None.
"""
ds_builder = tfds.builder_from_directory(str(raw_dir))
dataset = ds_builder.as_dataset(
split="all",
decoders={"steps": tfds.decode.SkipDecoding()},
)
dataset_info = ds_builder.info
print("dataset_info: ", dataset_info)
ds_length = len(dataset)
dataset = dataset.take(ds_length)
# "flatten" the dataset as such we can apply trajectory level map() easily
# each [obs][key] has a shape of (frame_size, ...)
dataset = dataset.enumerate().map(_broadcast_metadata_rlds)
# we will apply the standardization transform if the dataset_name is provided
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
# search for 'image' keys in the observations
if openx_dataset_name is not None:
print(" - applying standardization transform for dataset: ", openx_dataset_name)
assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
dataset = dataset.map(transform_fn)
image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
else:
obs_keys = dataset_info.features["steps"]["observation"].keys()
image_keys = [key for key in obs_keys if "image" in key]
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
print(" - image_keys: ", image_keys)
print(" - lang_key: ", lang_key)
it = iter(dataset)
ep_dicts = []
# Init temp path to save ep_dicts in case of crash
tmp_ep_dicts_dir = videos_dir.parent.joinpath("ep_dicts")
tmp_ep_dicts_dir.mkdir(parents=True, exist_ok=True)
# check if ep_dicts have already been saved in /tmp
starting_ep_idx = 0
saved_ep_dicts = [ep.__str__() for ep in tmp_ep_dicts_dir.iterdir()]
if len(saved_ep_dicts) > 0:
saved_ep_dicts.sort()
# get last ep_idx number
starting_ep_idx = int(saved_ep_dicts[-1][-13:-3]) + 1
for i in range(starting_ep_idx):
episode = next(it)
ep_dicts.append(torch.load(saved_ep_dicts[i]))
# if we user specified episodes, skip the ones not in the list
if episodes is not None:
if ds_length == 0:
raise ValueError("No episodes found.")
# convert episodes index to sorted list
episodes = sorted(episodes)
for ep_idx in tqdm.tqdm(range(starting_ep_idx, ds_length)):
episode = next(it)
# if user specified episodes, skip the ones not in the list
if episodes is not None:
if len(episodes) == 0:
break
if ep_idx == episodes[0]:
# process this episode
print(" selecting episode idx: ", ep_idx)
episodes.pop(0)
else:
continue # skip
num_frames = episode["action"].shape[0]
###########################################################
# Handle the episodic data
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
ep_dict = {}
langs = [] # TODO: might be located in "observation"
image_array_dict = {key: [] for key in image_keys}
# We will create the state observation tensor by stacking the state
# obs keys defined in the openx/configs.py
if openx_dataset_name is not None:
state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
# stack the state observations, if is None, pad with zeros
states = []
for key in state_obs_keys:
if key in episode["observation"]:
states.append(tf_to_torch(episode["observation"][key]))
else:
states.append(torch.zeros(num_frames, 1)) # pad with zeros
states = torch.cat(states, dim=1)
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
else:
states = tf_to_torch(episode["observation"]["state"])
actions = tf_to_torch(episode["action"])
rewards = tf_to_torch(episode["reward"]).float()
# If lang_key is present, convert the entire tensor at once
if lang_key is not None:
langs = [str(x) for x in episode[lang_key]]
for im_key in image_keys:
imgs = episode["observation"][im_key]
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
# simple assertions
for item in [states, actions, rewards, done]:
assert len(item) == num_frames
###########################################################
# loop through all cameras
for im_key in image_keys:
img_key = f"observation.images.{im_key}"
imgs_array = image_array_dict[im_key]
imgs_array = np.array(imgs_array)
if video:
# save png images in temporary directory
tmp_imgs_dir = videos_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
if lang_key is not None:
ep_dict["language_instruction"] = langs
ep_dict["observation.state"] = states
ep_dict["action"] = actions
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["next.reward"] = rewards
ep_dict["next.done"] = done
path_ep_dict = tmp_ep_dicts_dir.joinpath(
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
)
torch.save(ep_dict, path_ep_dict)
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
if "language_instruction" in data_dict:
features["language_instruction"] = Value(dtype="string", id=None)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.reward"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
encoding: dict | None = None,
openx_dataset_name: str | None = None,
):
"""This is a test impl for rlds conversion"""
if openx_dataset_name is None:
# set a default rlds frame rate if the dataset is not from openx
fps = 30
elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
raise ValueError(
"fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
)
fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps,
"video": video,
}
if video:
info["encoding"] = get_default_encoding()
return hf_dataset, episode_data_index, info

View File

@@ -23,11 +23,19 @@ from typing import Dict
import datasets
import torch
from datasets import load_dataset, load_from_disk
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [🤗 LeRobot](https://github.com/huggingface/lerobot).
"""
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@@ -72,6 +80,11 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif isinstance(first_item, str):
# TODO (michel-aractingi): add str2embedding via language tokenizer
# For now we leave this part up to the user to choose how to address
# language conditioned tasks
pass
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
@@ -385,3 +398,29 @@ def cycle(iterable):
yield next(iterator)
except StopIteration:
iterator = iter(iterable)
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""
api = HfApi()
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
refs = [branch.ref for branch in branches]
ref = f"refs/heads/{branch}"
if ref in refs:
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
card = DatasetCard(DATASET_CARD_TEMPLATE)
card.data.task_categories = ["robotics"]
card.data.tags = ["LeRobot"]
if tags is not None:
card.data.tags += tags
if text is not None:
card.text += text
return card

View File

@@ -210,6 +210,12 @@ def encode_video_frames(
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
if not video_path.exists():
raise OSError(
f"Video encoding did not work. File not found: {video_path}. "
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
)
@dataclass
class VideoFrame:

View File

@@ -38,7 +38,13 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
class ACTPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "act"],
):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -154,6 +160,31 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
return loss_dict
def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for ACT"""
optimizer_params_dicts = [
{
"params": [
p
for n, p in self.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in self.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": cfg.training.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
)
lr_scheduler = None
return optimizer, lr_scheduler
class ACTTemporalEnsembler:
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:

View File

@@ -43,7 +43,13 @@ from lerobot.common.policies.utils import (
)
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
class DiffusionPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "diffusion-policy"],
):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
@@ -111,12 +117,12 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
Schematically this looks like:
----------------------------------------------------------------------------------------------
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO |
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
@@ -150,6 +156,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for Diffusion policy"""
optimizer = torch.optim.Adam(
self.diffusion.parameters(),
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
return optimizer, lr_scheduler
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""

View File

@@ -41,7 +41,13 @@ from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
class TDMPCPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc"],
):
"""Implementation of TD-MPC learning + inference.
Please note several warnings for this policy.
@@ -528,6 +534,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for TD-MPC"""
optimizer = torch.optim.Adam(self.parameters(), cfg.training.lr)
lr_scheduler = None
return optimizer, lr_scheduler
class TDMPCTOLD(nn.Module):
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""

View File

@@ -38,7 +38,13 @@ from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
# ruff: noqa: N806
class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
class VQBeTPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "vqbet"],
):
"""
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
"""
@@ -146,6 +152,12 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
return loss_dict
def make_optimizer_and_scheduler(self, cfg):
"""Create the optimizer and learning rate scheduler for VQ-BeT"""
optimizer = VQBeTOptimizer(self, cfg)
scheduler = VQBeTScheduler(optimizer, cfg)
return optimizer, scheduler
class SpatialSoftmax(nn.Module):
"""
@@ -289,7 +301,7 @@ class VQBeTModel(nn.Module):
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.output_shapes["action"][0], hidden_channels=[self.config.gpt_input_dim]
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]

View File

@@ -5,6 +5,7 @@ This file contains utilities for recording frames from cameras. For more info lo
import argparse
import concurrent.futures
import math
import platform
import shutil
import threading
import time
@@ -33,8 +34,22 @@ MAX_OPENCV_INDEX = 60
def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX):
if platform.system() == "Linux":
# Linux uses camera ports
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
possible_camera_ids = []
for port in Path("/dev").glob("video*"):
camera_idx = int(str(port).replace("/dev/video", ""))
possible_camera_ids.append(camera_idx)
else:
print(
"Mac or Windows detected. Finding available camera indices through "
f"scanning all indices from 0 to {MAX_OPENCV_INDEX}"
)
possible_camera_ids = range(max_index_search_range)
camera_ids = []
for camera_idx in range(max_index_search_range):
for camera_idx in possible_camera_ids:
camera = cv2.VideoCapture(camera_idx)
is_open = camera.isOpened()
camera.release()
@@ -45,7 +60,8 @@ def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENC
if raise_when_empty and len(camera_ids) == 0:
raise OSError(
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2."
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, "
"or your camera driver, or make sure your camera is compatible with opencv2."
)
return camera_ids
@@ -59,10 +75,9 @@ def save_image(img_array, camera_index, frame_index, images_dir):
def save_images_from_cameras(
images_dir: Path, camera_ids=None, fps=None, width=None, height=None, record_time_s=2
images_dir: Path, camera_ids: list[int] | None = None, fps=None, width=None, height=None, record_time_s=2
):
if camera_ids is None:
print("Finding available camera indices")
camera_ids = find_camera_indices()
print("Connecting cameras")
@@ -71,13 +86,12 @@ def save_images_from_cameras(
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height)
camera.connect()
print(
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
f"height={camera.height}, color_mode={camera.color_mode})"
)
cameras.append(camera)
images_dir = Path(
images_dir,
)
images_dir = Path(images_dir)
if images_dir.exists():
shutil.rmtree(
images_dir,
@@ -160,7 +174,7 @@ class OpenCVCamera:
When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
of the given camera will be used.
Example of usage of the class:
Example of usage:
```python
camera = OpenCVCamera(camera_index=0)
camera.connect()
@@ -194,11 +208,6 @@ class OpenCVCamera:
self.height = config.height
self.color_mode = config.color_mode
if not isinstance(self.camera_index, int):
raise ValueError(
f"Camera index must be provided as an int, but {self.camera_index} was given instead."
)
self.camera = None
self.is_connected = False
self.thread = None
@@ -212,7 +221,13 @@ class OpenCVCamera:
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(self.camera_index)
if platform.system() == "Linux":
# Linux uses ports for connecting to cameras
tmp_camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
else:
tmp_camera = cv2.VideoCapture(self.camera_index)
is_camera_open = tmp_camera.isOpened()
# Release camera to make it accessible for `find_camera_indices`
del tmp_camera
@@ -224,7 +239,8 @@ class OpenCVCamera:
available_cam_ids = find_camera_indices()
if self.camera_index not in available_cam_ids:
raise ValueError(
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead."
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
"To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
)
raise OSError(f"Can't access camera {self.camera_index}.")
@@ -232,7 +248,10 @@ class OpenCVCamera:
# Secondly, create the camera that will be used downstream.
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
# needs to be re-created.
self.camera = cv2.VideoCapture(self.camera_index)
if platform.system() == "Linux":
self.camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
else:
self.camera = cv2.VideoCapture(self.camera_index)
if self.fps is not None:
self.camera.set(cv2.CAP_PROP_FPS, self.fps)

View File

@@ -2,6 +2,7 @@ from pathlib import Path
from typing import Protocol
import cv2
import einops
import numpy as np
@@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
cv2.imwrite(str(path), depth_image)
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
assert tensor.ndim == 3
c, h, w = tensor.shape
assert c < h and c < w
color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
if rgb_to_bgr:
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
return color_image
# Defines a camera type
class Camera(Protocol):
def connect(self): ...

View File

@@ -5,6 +5,7 @@ from copy import deepcopy
from pathlib import Path
import numpy as np
import tqdm
from dynamixel_sdk import (
COMM_SUCCESS,
DXL_HIBYTE,
@@ -21,9 +22,11 @@ from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError,
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 2.0
BAUD_RATE = 1_000_000
BAUDRATE = 1_000_000
TIMEOUT_MS = 1000
MAX_ID_RANGE = 252
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
@@ -86,6 +89,16 @@ X_SERIES_CONTROL_TABLE = {
"Present_Temperature": (146, 1),
}
X_SERIES_BAUDRATE_TABLE = {
0: 9_600,
1: 57_600,
2: 115_200,
3: 1_000_000,
4: 2_000_000,
5: 3_000_000,
6: 4_000_000,
}
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
@@ -98,7 +111,67 @@ MODEL_CONTROL_TABLE = {
"xm540-w270": X_SERIES_CONTROL_TABLE,
}
MODEL_RESOLUTION = {
"x_series": 4096,
"xl330-m077": 4096,
"xl330-m288": 4096,
"xl430-w250": 4096,
"xm430-w350": 4096,
"xm540-w270": 4096,
}
MODEL_BAUDRATE_TABLE = {
"x_series": X_SERIES_BAUDRATE_TABLE,
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
}
NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]):
"""This function convert the degree range to the step range for indicating motors rotation.
It assums a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
"""
if isinstance(degrees, float):
degrees = np.array(degrees)
resolutions = [MODEL_RESOLUTION[model] for model in models]
steps = degrees / 180 * np.array(resolutions) / 2
steps = steps.astype(int)
return steps
def convert_to_bytes(value, bytes):
# Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us.
if bytes == 1:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
]
elif bytes == 2:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
]
elif bytes == 4:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
DXL_LOBYTE(DXL_HIWORD(value)),
DXL_HIBYTE(DXL_HIWORD(value)),
]
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead."
)
return data
def get_group_sync_key(data_name, motor_names):
@@ -207,13 +280,12 @@ class DynamixelMotorsBus:
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
>>> Reconnect the usb cable.
```
To find the motor indices, use [DynamixelWizzard2](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2).
Example of usage for 1 motor connected to the bus:
```python
motor_name = "gripper"
motor_index = 6
motor_model = "xl330-m077"
motor_model = "xl330-m288"
motors_bus = DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
@@ -221,7 +293,11 @@ class DynamixelMotorsBus:
)
motors_bus.connect()
motors_bus.teleop_step()
position = motors_bus.read("Present_Position")
# move from a few motor steps as an example
few_steps = 30
motors_bus.write("Goal_Position", position + few_steps)
# when done, consider disconnecting
motors_bus.disconnect()
@@ -233,6 +309,7 @@ class DynamixelMotorsBus:
port: str,
motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None,
extra_model_resolution: dict[str, int] | None = None,
):
self.port = port
self.motors = motors
@@ -241,6 +318,10 @@ class DynamixelMotorsBus:
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.model_resolution = deepcopy(MODEL_RESOLUTION)
if extra_model_resolution:
self.model_resolution.update(extra_model_resolution)
self.port_handler = None
self.packet_handler = None
self.calibration = None
@@ -268,52 +349,286 @@ class DynamixelMotorsBus:
)
raise
self.port_handler.setBaudRate(BAUD_RATE)
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
# Allow to read and write
self.is_connected = True
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
# Set expected baudrate for the bus
self.set_bus_baudrate(BAUDRATE)
if not self.are_motors_configured():
input(
"\n/!\\ A configuration issue has been detected with your motors: \n"
"If it's the first time that you use these motors, press enter to configure your motors... but before "
"verify that all the cables are connected the proper way. If you find an issue, before making a modification, "
"kill the python process, unplug the power cord to not damage the motors, rewire correctly, then plug the power "
"again and relaunch the script.\n"
)
print()
self.configure_motors()
def reconnect(self):
self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
self.is_connected = True
def are_motors_configured(self):
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
# a ConnectionError will be raised anyway.
try:
return (self.motor_indices == self.read("ID")).all()
except ConnectionError as e:
print(e)
return False
def configure_motors(self):
# TODO(rcadene): This script assumes motors follow the X_SERIES baudrates
# TODO(rcadene): Refactor this function with intermediate high-level functions
print("Scanning all baudrates and motor indices")
all_baudrates = set(X_SERIES_BAUDRATE_TABLE.values())
ids_per_baudrate = {}
for baudrate in all_baudrates:
self.set_bus_baudrate(baudrate)
present_ids = self.find_motor_indices()
if len(present_ids) > 0:
ids_per_baudrate[baudrate] = present_ids
print(f"Motor indices detected: {ids_per_baudrate}")
print()
possible_baudrates = list(ids_per_baudrate.keys())
possible_ids = list({idx for sublist in ids_per_baudrate.values() for idx in sublist})
untaken_ids = list(set(range(MAX_ID_RANGE)) - set(possible_ids) - set(self.motor_indices))
# Connect successively one motor to the chain and write a unique random index for each
for i in range(len(self.motors)):
self.disconnect()
input(
"1. Unplug the power cord\n"
"2. Plug/unplug minimal number of cables to only have the first "
f"{i+1} motor(s) ({self.motor_names[:i+1]}) connected.\n"
"3. Re-plug the power cord\n"
"Press Enter to continue..."
)
print()
self.reconnect()
if i > 0:
try:
self._read_with_motor_ids(self.motor_models, untaken_ids[:i], "ID")
except ConnectionError:
print(f"Failed to read from {untaken_ids[:i+1]}. Make sure the power cord is plugged in.")
input("Press Enter to continue...")
print()
self.reconnect()
print("Scanning possible baudrates and motor indices")
motor_found = False
for baudrate in possible_baudrates:
self.set_bus_baudrate(baudrate)
present_ids = self.find_motor_indices(possible_ids)
if len(present_ids) == 1:
present_idx = present_ids[0]
print(f"Detected motor with index {present_idx}")
if baudrate != BAUDRATE:
print(f"Setting its baudrate to {BAUDRATE}")
baudrate_idx = list(X_SERIES_BAUDRATE_TABLE.values()).index(BAUDRATE)
# The write can fail, so we allow retries
for _ in range(NUM_WRITE_RETRY):
self._write_with_motor_ids(
self.motor_models, present_idx, "Baud_Rate", baudrate_idx
)
time.sleep(0.5)
self.set_bus_baudrate(BAUDRATE)
try:
present_baudrate_idx = self._read_with_motor_ids(
self.motor_models, present_idx, "Baud_Rate"
)
except ConnectionError:
print("Failed to write baudrate. Retrying.")
self.set_bus_baudrate(baudrate)
continue
break
else:
raise
if present_baudrate_idx != baudrate_idx:
raise OSError("Failed to write baudrate.")
print(f"Setting its index to a temporary untaken index ({untaken_ids[i]})")
self._write_with_motor_ids(self.motor_models, present_idx, "ID", untaken_ids[i])
present_idx = self._read_with_motor_ids(self.motor_models, untaken_ids[i], "ID")
if present_idx != untaken_ids[i]:
raise OSError("Failed to write index.")
motor_found = True
break
elif len(present_ids) > 1:
raise OSError(f"More than one motor detected ({present_ids}), but only one was expected.")
if not motor_found:
raise OSError(
"No motor found, but one new motor expected. Verify power cord is plugged in and retry."
)
print()
print(f"Setting expected motor indices: {self.motor_indices}")
self.set_bus_baudrate(BAUDRATE)
self._write_with_motor_ids(
self.motor_models, untaken_ids[: len(self.motors)], "ID", self.motor_indices
)
print()
if (self.read("ID") != self.motor_indices).any():
raise OSError("Failed to write motors indices.")
print("Configuration is done!")
def find_motor_indices(self, possible_ids=None):
if possible_ids is None:
possible_ids = range(MAX_ID_RANGE)
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self._read_with_motor_ids(self.motor_models, [idx], "ID")[0]
except ConnectionError:
continue
if idx != present_idx:
# sanity check
raise OSError(
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
)
indices.append(idx)
return indices
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
raise OSError("Failed to write bus baud rate.")
@property
def motor_names(self) -> list[int]:
def motor_names(self) -> list[str]:
return list(self.motors.keys())
@property
def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()]
@property
def motor_indices(self) -> list[int]:
return [idx for idx, _ in self.motors.values()]
def set_calibration(self, calibration: dict[str, tuple[int, bool]]):
self.calibration = calibration
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration:
return values
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
in the centered nominal degree range ]-180, 180[.
"""
if motor_names is None:
motor_names = self.motor_names
# Convert from unsigned int32 original range [0, 2**32[ to centered signed int32 range [-2**31, 2**31[
values = values.astype(np.int32)
for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name]
if values[i] is not None:
if drive_mode:
values[i] *= -1
values[i] += homing_offset
# Update direction of rotation of the motor to match between leader and follower. In fact, the motor of the leader for a given joint
# can be assembled in an opposite direction in term of rotation than the motor of the follower on the same joint.
if drive_mode:
values[i] *= -1
# Convert from range [-2**31, 2**31[ to nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
values[i] += homing_offset
# Convert from range ]-resolution, resolution[ to the universal float32 centered degree range ]-180, 180[
values = values.astype(np.float32)
for i, name in enumerate(motor_names):
_, model = self.motors[name]
resolution = self.model_resolution[model]
values[i] = values[i] / (resolution // 2) * 180
return values
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration:
return values
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
# Convert from the universal float32 centered degree range ]-180, 180[ to resolution range ]-resolution, resolution[
for i, name in enumerate(motor_names):
_, model = self.motors[name]
resolution = self.model_resolution[model]
values[i] = values[i] / 180 * (resolution // 2)
values = np.round(values).astype(np.int32)
# Convert from nominal range ]-resolution, resolution[ to centered signed int32 range [-2**31, 2**31[
for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name]
values[i] -= homing_offset
if values[i] is not None:
values[i] -= homing_offset
if drive_mode:
values[i] *= -1
# Update direction of rotation of the motor that was matching between leader and follower to their original direction.
# In fact, the motor of the leader for a given joint can be assembled in an opposite direction in term of rotation
# than the motor of the follower on the same joint.
if drive_mode:
values[i] *= -1
return values
def _read_with_motor_ids(self, motor_models, motor_ids, data_name):
return_list = True
if not isinstance(motor_ids, list):
return_list = False
motor_ids = [motor_ids]
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
group = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
for idx in motor_ids:
group.addParam(idx)
comm = group.txRxPacket()
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = group.getData(idx, addr, bytes)
values.append(value)
if return_list:
return values
else:
return values[0]
def read(self, data_name, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
@@ -367,9 +682,21 @@ class DynamixelMotorsBus:
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
values = values.astype(np.int32)
if data_name in CALIBRATION_REQUIRED:
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.apply_calibration(values, motor_names)
# We expect our motors to stay in a nominal range of [-180, 180] degrees
# which corresponds to a half turn rotation.
# However, some motors can turn a bit more, hence we extend the nominal range to [-270, 270]
# which is less than a full 360 degree rotation.
if not np.all((values > -270) & (values < 270)):
raise ValueError(
f"Wrong motor position range detected. "
f"Expected to be in [-270, +270] but in [{values.min()}, {values.max()}]. "
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
@@ -380,6 +707,26 @@ class DynamixelMotorsBus:
return values
def _write_with_motor_ids(self, motor_models, motor_ids, data_name, values):
if not isinstance(motor_ids, list):
motor_ids = [motor_ids]
if not isinstance(values, list):
values = [values]
assert_same_address(self.model_ctrl_table, motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
group = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes)
group.addParam(idx, data)
comm = group.txPacket()
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
@@ -406,7 +753,7 @@ class DynamixelMotorsBus:
motor_ids.append(motor_idx)
models.append(model)
if data_name in CALIBRATION_REQUIRED:
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.revert_calibration(values, motor_names)
values = values.tolist()
@@ -422,30 +769,7 @@ class DynamixelMotorsBus:
)
for idx, value in zip(motor_ids, values, strict=True):
# Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us.
if bytes == 1:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
]
elif bytes == 2:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
]
elif bytes == 4:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
DXL_LOBYTE(DXL_HIWORD(value)),
DXL_HIBYTE(DXL_HIWORD(value)),
]
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead."
)
data = convert_to_bytes(value, bytes)
if init_group:
self.group_writers[group_key].addParam(idx, data)
else:

View File

@@ -1,46 +1,7 @@
def make_robot(name):
if name == "koch":
# TODO(rcadene): Add configurable robot from command line and yaml config
# TODO(rcadene): Add example with and without cameras
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.robots.koch import KochRobot
import hydra
from omegaconf import DictConfig
robot = KochRobot(
leader_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl330-m077"),
"shoulder_lift": (2, "xl330-m077"),
"elbow_flex": (3, "xl330-m077"),
"wrist_flex": (4, "xl330-m077"),
"wrist_roll": (5, "xl330-m077"),
"gripper": (6, "xl330-m077"),
},
),
},
follower_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
},
),
},
cameras={
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
},
)
else:
raise ValueError(f"Robot '{name}' not found.")
def make_robot(cfg: DictConfig):
robot = hydra.utils.instantiate(cfg)
return robot

View File

@@ -1,129 +1,51 @@
import logging
import pickle
import time
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Sequence
import numpy as np
import torch
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
TorqueMode,
convert_degrees_to_steps,
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
URL_HORIZONTAL_POSITION = {
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_horizontal.png",
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_horizontal.png",
}
URL_90_DEGREE_POSITION = {
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_90_degree.png",
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_90_degree.png",
}
########################################################################
# Calibration logic
########################################################################
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
GRIPPER_OPEN = np.array([-400])
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
# In nominal degree range ]-180, +180[
ZERO_POSITION_DEGREE = 0
ROTATED_POSITION_DEGREE = 90
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
for i in range(len(values)):
if values[i] is not None:
values[i] += homing_offset[i]
return values
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
for i in range(len(values)):
if values[i] is not None and drive_mode[i]:
values[i] = -values[i]
return values
def apply_drive_mode(position, drive_mode):
assert_drive_mode(drive_mode)
# Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted,
# to [-1, 1] with 1 indicates original rotation direction and -1 inverted.
signed_drive_mode = -(drive_mode * 2 - 1)
position *= signed_drive_mode
return position
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
values = apply_drive_mode(values, drive_mode)
values = apply_homing_offset(values, homing_offset)
return values
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
"""
Transform working position into real position for the robot.
"""
values = apply_homing_offset(
values,
np.array([-homing_offset if homing_offset is not None else None for homing_offset in homing_offset]),
)
values = apply_drive_mode(values, drive_mode)
return values
def revert_appropriate_positions(positions: np.array, drive_mode: list[bool]) -> np.array:
for i, revert in enumerate(drive_mode):
if not revert and positions[i] is not None:
positions[i] = -positions[i]
return positions
def compute_corrections(positions: np.array, drive_mode: list[bool], target_position: np.array) -> np.array:
correction = revert_appropriate_positions(positions, drive_mode)
for i in range(len(positions)):
if correction[i] is not None:
if drive_mode[i]:
correction[i] -= target_position[i]
else:
correction[i] += target_position[i]
return correction
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
return np.array(
[
round(positions[i] / 1024) * 1024 if positions[i] is not None else None
for i in range(len(positions))
]
)
def compute_homing_offset(
arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array
) -> np.array:
# Get the present positions of the servos
present_positions = apply_calibration(
arm.read("Present_Position"), np.array([0, 0, 0, 0, 0, 0]), drive_mode
)
nearest_positions = compute_nearest_rounded_positions(present_positions)
correction = compute_corrections(nearest_positions, drive_mode, target_position)
return correction
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
# Get current positions
present_positions = apply_calibration(
arm.read("Present_Position"), offset, np.array([False, False, False, False, False, False])
)
nearest_positions = compute_nearest_rounded_positions(present_positions)
# construct 'drive_mode' list comparing nearest_positions and TARGET_90_DEGREE_POSITION
drive_mode = []
for i in range(len(nearest_positions)):
drive_mode.append(nearest_positions[i] != TARGET_90_DEGREE_POSITION[i])
return drive_mode
def reset_arm(arm: MotorsBus):
def reset_torque_mode(arm: MotorsBus):
# To be configured, all servos must be in "torque disable" mode
arm.write("Torque_Enable", TorqueMode.DISABLED.value)
@@ -132,55 +54,95 @@ def reset_arm(arm: MotorsBus):
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
if len(all_motors_except_gripper) > 0:
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
# TODO(rcadene): why?
# Use 'position control current based' for gripper
# Use 'position control current based' for gripper to be limited by the limit of the current.
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
# it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
# to make it move, and it will move back to its original target position when we release the force.
arm.write("Operating_Mode", OperatingMode.CURRENT_CONTROLLED_POSITION.value, "gripper")
# Make sure the native calibration (homing offset abd drive mode) is disabled, since we use our own calibration layer to be more generic
arm.write("Homing_Offset", 0)
arm.write("Drive_Mode", DriveMode.NON_INVERTED.value)
def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str):
"""Example of usage:
"""This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance before calibration, setting a same goal position
for each motor of two different robots will get two very different positions. But after calibration,
the two robots will move to the same position.To this end, this function computes the homing offset
and the drive mode for each motor of a given robot.
Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
being 0. During the calibration process, you will need to manually move the robot to this "zero position".
Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
to the "rotated position".
After calibration, the homing offsets and drive modes are stored in a cache.
Example of usage:
```python
run_arm_calibration(arm, "left", "follower")
```
"""
reset_arm(arm)
reset_torque_mode(arm)
# TODO(rcadene): document what position 1 mean
print(
f"Please move the '{name} {arm_type}' arm to the horizontal position (gripper fully closed, see {URL_HORIZONTAL_POSITION[arm_type]})"
)
print(f"\nRunning calibration of {name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="zero"))
input("Press Enter to continue...")
horizontal_homing_offset = compute_homing_offset(
arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION
)
# We arbitrarely choosed our zero target position to be a straight horizontal position with gripper upwards and closed.
# It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
# corresponds to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
zero_position = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
# TODO(rcadene): document what position 2 mean
print(
f"Please move the '{name} {arm_type}' arm to the 90 degree position (gripper fully open, see {URL_90_DEGREE_POSITION[arm_type]})"
)
def _compute_nearest_rounded_position(position, models):
# TODO(rcadene): Rework this function since some motors cant physically rotate a quarter turn
# (e.g. the gripper of Aloha arms can only rotate ~50 degree)
quarter_turn_degree = 90
quarter_turn = convert_degrees_to_steps(quarter_turn_degree, models)
nearest_pos = np.round(position.astype(float) / quarter_turn) * quarter_turn
return nearest_pos.astype(position.dtype)
# Compute homing offset so that `present_position + homing_offset ~= target_position`.
position = arm.read("Present_Position")
position = _compute_nearest_rounded_position(position, arm.motor_models)
homing_offset = zero_position - position
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rotated"))
input("Press Enter to continue...")
drive_mode = compute_drive_mode(arm, horizontal_homing_offset)
homing_offset = compute_homing_offset(arm, drive_mode, TARGET_90_DEGREE_POSITION)
# The rotated target position corresponds to a rotation of a quarter turn from the zero position.
# This allows to identify the rotation direction of each motor.
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
rotated_position = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
# Invert offset for all drive_mode servos
for i in range(len(drive_mode)):
if drive_mode[i]:
homing_offset[i] = -homing_offset[i]
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
position = arm.read("Present_Position")
position += homing_offset
position = _compute_nearest_rounded_position(position, arm.motor_models)
drive_mode = (position != rotated_position).astype(np.int32)
print("Calibration is done!")
# Re-compute homing offset to take into account drive mode
position = arm.read("Present_Position")
position = apply_drive_mode(position, drive_mode)
position = _compute_nearest_rounded_position(position, arm.motor_models)
homing_offset = rotated_position - position
print("=====================================")
print(" HOMING_OFFSET: ", " ".join([str(i) for i in homing_offset]))
print(" DRIVE_MODE: ", " ".join([str(i) for i in drive_mode]))
print("=====================================")
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rest"))
input("Press Enter to continue...")
print()
return homing_offset, drive_mode
@@ -204,10 +166,39 @@ class KochRobotConfig:
follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
cameras: dict[str, Camera] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
def __setattr__(self, prop: str, val):
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(val):
raise ValueError(
f"len(max_relative_target)={len(val)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
super().__setattr__(prop, val)
class KochRobot:
# TODO(rcadene): Implement force feedback
"""Tau Robotics: https://tau-robotics.com
"""This class allows to control any Koch robot of various number of motors.
A few versions are available:
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, which was developed
by Alexander Koch from [Tau Robotics](https://tau-robotics.com): [Github for sourcing and assembly](
- [Koch v1.1])https://github.com/jess-moss/koch-v1-1), which was developed by Jess Moss.
Example of highest frequency teleoperation without camera:
```python
@@ -240,7 +231,10 @@ class KochRobot:
},
),
}
robot = KochRobot(leader_arms, follower_arms)
robot = KochRobot(
leader_arms=leader_arms,
follower_arms=follower_arms,
)
# Connect motors buses and cameras if any (Required)
robot.connect()
@@ -252,7 +246,10 @@ class KochRobot:
Example of highest frequency data collection without camera:
```python
# Assumes leader and follower arms have been instantiated already (see first example)
robot = KochRobot(leader_arms, follower_arms)
robot = KochRobot(
leader_arms=leader_arms,
follower_arms=follower_arms,
)
robot.connect()
while True:
observation, action = robot.teleop_step(record_data=True)
@@ -261,16 +258,20 @@ class KochRobot:
Example of highest frequency data collection with cameras:
```python
# Defines how to communicate with 2 cameras connected to the computer.
# Here, the webcam of the mackbookpro and the iphone (connected in USB to the macbookpro)
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
# can be reached respectively using the camera indices 0 and 1. These indices can be
# arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices.
cameras = {
"macbookpro": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
"iphone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
}
# Assumes leader and follower arms have been instantiated already (see first example)
robot = KochRobot(leader_arms, follower_arms, cameras)
robot = KochRobot(
leader_arms=leader_arms,
follower_arms=follower_arms,
cameras=cameras,
)
robot.connect()
while True:
observation, action = robot.teleop_step(record_data=True)
@@ -279,7 +280,11 @@ class KochRobot:
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency):
```python
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
robot = KochRobot(leader_arms, follower_arms, cameras)
robot = KochRobot(
leader_arms=leader_arms,
follower_arms=follower_arms,
cameras=cameras,
)
robot.connect()
while True:
# Uses the follower arms and cameras to capture an observation
@@ -330,23 +335,27 @@ class KochRobot:
# Connect the arms
for name in self.follower_arms:
print(f"Connecting {name} follower arm.")
self.follower_arms[name].connect()
print(f"Connecting {name} leader arm.")
self.leader_arms[name].connect()
# Reset the arms and load or run calibration
if self.calibration_path.exists():
# Reset all arms before setting calibration
for name in self.follower_arms:
reset_arm(self.follower_arms[name])
reset_torque_mode(self.follower_arms[name])
for name in self.leader_arms:
reset_arm(self.leader_arms[name])
reset_torque_mode(self.leader_arms[name])
with open(self.calibration_path, "rb") as f:
calibration = pickle.load(f)
else:
print(f"Missing calibration file '{self.calibration_path}'. Starting calibration precedure.")
# Run calibration process which begins by reseting all arms
calibration = self.run_calibration()
print(f"Calibration is done! Saving calibration file '{self.calibration_path}'")
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.calibration_path, "wb") as f:
pickle.dump(calibration, f)
@@ -366,13 +375,15 @@ class KochRobot:
# Enable torque on all motors of the follower arms
for name in self.follower_arms:
print(f"Activating torque on {name} follower arm.")
self.follower_arms[name].write("Torque_Enable", 1)
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN, "gripper")
if self.config.gripper_open_degree is not None:
# Set the leader arm in torque mode with the gripper motor set to an angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
# Connect the cameras
for name in self.cameras:
@@ -407,12 +418,12 @@ class KochRobot:
"KochRobot is not connected. You need to run `robot.connect()`."
)
# Prepare to assign the positions of the leader to the follower
# Prepare to assign the position of the leader to the follower
leader_pos = {}
for name in self.leader_arms:
now = time.perf_counter()
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - now
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
follower_goal_pos = {}
for name in self.leader_arms:
@@ -420,9 +431,9 @@ class KochRobot:
# Send action
for name in self.follower_arms:
now = time.perf_counter()
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - now
before_fwrite_t = time.perf_counter()
self.send_action(torch.tensor(follower_goal_pos[name]), [name])
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
# Early exit when recording data is not requested
if not record_data:
@@ -432,9 +443,9 @@ class KochRobot:
# Read follower position
follower_pos = {}
for name in self.follower_arms:
now = time.perf_counter()
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
@@ -453,10 +464,10 @@ class KochRobot:
# Capture images from cameras
images = {}
for name in self.cameras:
now = time.perf_counter()
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries and format to pytorch
obs_dict, action_dict = {}, {}
@@ -477,9 +488,9 @@ class KochRobot:
# Read follower position
follower_pos = {}
for name in self.follower_arms:
now = time.perf_counter()
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
@@ -491,37 +502,69 @@ class KochRobot:
# Capture images from cameras
images = {}
for name in self.cameras:
now = time.perf_counter()
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = torch.from_numpy(state)
for name in self.cameras:
# Convert to pytorch format: channel first and float32 in [0,1]
img = torch.from_numpy(images[name])
img = img.type(torch.float32) / 255
img = img.permute(2, 0, 1).contiguous()
obs_dict[f"observation.images.{name}"] = img
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
return obs_dict
def send_action(self, action: torch.Tensor):
"""The provided action is expected to be a vector."""
def send_action(self, action: torch.Tensor, follower_names: list[str] | None = None):
"""Command the follower arms to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`.
Args:
action: tensor containing the concatenated joint positions for the follower arms.
follower_names: Pass follower arm names to only control a subset of all the follower arms.
"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"KochRobot is not connected. You need to run `robot.connect()`."
)
if follower_names is None:
follower_names = list(self.follower_arms)
elif not set(follower_names).issubset(self.follower_arms):
raise ValueError(
f"You provided {follower_names=} but only the following arms are registered: "
f"{list(self.follower_arms)}"
)
from_idx = 0
to_idx = 0
follower_goal_pos = {}
for name in self.follower_arms:
if name in self.follower_arms:
to_idx += len(self.follower_arms[name].motor_names)
follower_goal_pos[name] = action[from_idx:to_idx].numpy()
from_idx = to_idx
for name in follower_names:
to_idx += len(self.follower_arms[name].motor_names)
this_action = action[from_idx:to_idx]
if self.config.max_relative_target is not None:
if not isinstance(self.config.max_relative_target, list):
max_relative_target = [self.config.max_relative_target for _ in range(from_idx, to_idx)]
max_relative_target = torch.tensor(self.config.max_relative_target)
# Cap relative action target magnitude for safety.
current_pos = torch.tensor(self.follower_arms[name].read("Present_Position"))
diff = this_action - current_pos
safe_diff = torch.minimum(diff, max_relative_target)
safe_diff = torch.maximum(safe_diff, -max_relative_target)
safe_action = current_pos + safe_diff
if not torch.allclose(safe_action, this_action):
logging.warning(
"Relative action magnitude had to be clamped to be safe.\n"
f" requested relative action target: {diff}\n"
f" clamped relative action target: {safe_diff}"
)
follower_goal_pos[name] = safe_action.numpy()
else:
follower_goal_pos[name] = this_action.numpy()
from_idx = to_idx
for name in self.follower_arms:
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import os.path as osp
import random
from contextlib import contextmanager
@@ -27,6 +28,12 @@ import torch
from omegaconf import DictConfig
def inside_slurm():
"""Check whether the python process was launched through slurm"""
# TODO(rcadene): return False for interactive mode `--pty bash`
return "SLURM_JOB_ID" in os.environ
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match cfg_device:

View File

@@ -120,7 +120,7 @@ eval:
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
batch_size: 1
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: false
use_async_envs: true
wandb:
enable: false

View File

@@ -2,6 +2,11 @@
fps: 50
eval:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# set it to false to avoid some problems of the aloha env
use_async_envs: false
env:
name: aloha
task: AlohaInsertion-v0

View File

@@ -2,6 +2,11 @@
fps: 15
eval:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# set it to false to avoid some problems of the aloha env
use_async_envs: false
env:
name: xarm
task: XarmLift-v0

View File

@@ -0,0 +1,46 @@
_target_: lerobot.common.robot_devices.robots.koch.KochRobot
calibration_path: .cache/calibration/koch.pkl
leader_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0031751
motors:
# name: (index, model)
shoulder_pan: [1, "xl330-m077"]
shoulder_lift: [2, "xl330-m077"]
elbow_flex: [3, "xl330-m077"]
wrist_flex: [4, "xl330-m077"]
wrist_roll: [5, "xl330-m077"]
gripper: [6, "xl330-m077"]
follower_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0032081
motors:
# name: (index, model)
shoulder_pan: [1, "xl430-w250"]
shoulder_lift: [2, "xl430-w250"]
elbow_flex: [3, "xl330-m288"]
wrist_flex: [4, "xl330-m288"]
wrist_roll: [5, "xl330-m288"]
gripper: [6, "xl330-m288"]
cameras:
laptop:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 0
fps: 30
width: 640
height: 480
phone:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 1
fps: 30
width: 640
height: 480
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: null
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: 35.156

View File

@@ -0,0 +1,68 @@
_target_: lerobot.common.robot_devices.robots.koch.KochRobot
calibration_path: .cache/calibration/koch_bimanual.pkl
leader_arms:
left:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem585A0085511
motors:
# name: (index, model)
shoulder_pan: [1, "xl330-m077"]
shoulder_lift: [2, "xl330-m077"]
elbow_flex: [3, "xl330-m077"]
wrist_flex: [4, "xl330-m077"]
wrist_roll: [5, "xl330-m077"]
gripper: [6, "xl330-m077"]
right:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0031751
motors:
# name: (index, model)
shoulder_pan: [1, "xl330-m077"]
shoulder_lift: [2, "xl330-m077"]
elbow_flex: [3, "xl330-m077"]
wrist_flex: [4, "xl330-m077"]
wrist_roll: [5, "xl330-m077"]
gripper: [6, "xl330-m077"]
follower_arms:
left:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem585A0076891
motors:
# name: (index, model)
shoulder_pan: [1, "xl430-w250"]
shoulder_lift: [2, "xl430-w250"]
elbow_flex: [3, "xl330-m288"]
wrist_flex: [4, "xl330-m288"]
wrist_roll: [5, "xl330-m288"]
gripper: [6, "xl330-m288"]
right:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0032081
motors:
# name: (index, model)
shoulder_pan: [1, "xl430-w250"]
shoulder_lift: [2, "xl430-w250"]
elbow_flex: [3, "xl330-m288"]
wrist_flex: [4, "xl330-m288"]
wrist_roll: [5, "xl330-m288"]
gripper: [6, "xl330-m288"]
cameras:
laptop:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 0
fps: 30
width: 640
height: 480
phone:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 1
fps: 30
width: 640
height: 480
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: null
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: 35.156

View File

@@ -1,9 +1,22 @@
"""
Utilities to control a robot.
Useful to record a dataset, replay a recorded episode, run the policy on your robot
and record an evaluation dataset, and to recalibrate your robot if needed.
Examples of usage:
- Recalibrate your robot:
```bash
python lerobot/scripts/control_robot.py calibrate
```
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
```bash
python lerobot/scripts/control_robot.py teleoperate
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
python lerobot/scripts/control_robot.py teleoperate --robot-overrides '~cameras'
```
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
@@ -14,7 +27,7 @@ python lerobot/scripts/control_robot.py teleoperate \
- Record one episode in order to test replay:
```bash
python lerobot/scripts/control_robot.py record_dataset \
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
@@ -32,7 +45,7 @@ python lerobot/scripts/visualize_dataset.py \
- Replay this test episode:
```bash
python lerobot/scripts/control_robot.py replay_episode \
python lerobot/scripts/control_robot.py replay \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
@@ -42,12 +55,11 @@ python lerobot/scripts/control_robot.py replay_episode \
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
```bash
python lerobot/scripts/control_robot.py record_dataset \
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root data \
--repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \
--run-compute-stats 1 \
--warmup-time-s 2 \
--episode-time-s 30 \
--reset-time-s 10
@@ -74,7 +86,14 @@ DATA_DIR=data python lerobot/scripts/train.py \
- Run the pretrained policy on the robot:
```bash
python lerobot/scripts/control_robot.py run_policy \
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root data \
--repo-id $USER/eval_act_koch_real \
--num-episodes 10 \
--warmup-time-s 2 \
--episode-time-s 30 \
--reset-time-s 10
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
```
"""
@@ -87,12 +106,14 @@ import os
import platform
import shutil
import time
import traceback
from contextlib import nullcontext
from functools import cache
from pathlib import Path
import cv2
import torch
import tqdm
from huggingface_hub import create_branch
from omegaconf import DictConfig
from PIL import Image
from termcolor import colored
@@ -102,20 +123,45 @@ from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
from lerobot.common.datasets.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
from lerobot.scripts.push_dataset_to_hub import (
push_dataset_card_to_hub,
push_meta_data_to_hub,
push_videos_to_hub,
save_meta_data,
)
########################################################################################
# Utilities
########################################################################################
def say(text, blocking=False):
# Check if mac, linux, or windows.
if platform.system() == "Darwin":
cmd = f'say "{text}"'
elif platform.system() == "Linux":
cmd = f'spd-say "{text}"'
elif platform.system() == "Windows":
cmd = (
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
)
if not blocking and platform.system() in ["Darwin", "Linux"]:
# TODO(rcadene): Make it work for Windows
# Use the ampersand to run command in the background
cmd += " &"
os.system(cmd)
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
img = Image.fromarray(img_tensor.numpy())
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
@@ -160,11 +206,11 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
for name in robot.follower_arms:
key = f"write_follower_{name}_goal_pos_dt_s"
if key in robot.logs:
log_dt("dtRfoll", robot.logs[key])
log_dt("dtWfoll", robot.logs[key])
key = f"read_follower_{name}_pos_dt_s"
if key in robot.logs:
log_dt("dtWfoll", robot.logs[key])
log_dt("dtRfoll", robot.logs[key])
for name in robot.cameras:
key = f"read_camera_{name}_dt_s"
@@ -179,12 +225,23 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
logging.info(info_str)
def get_is_headless():
if platform.system() == "Linux":
display = os.environ.get("DISPLAY")
if display is None or display == "":
return True
return False
@cache
def is_headless():
"""Detects if python is running without a monitor."""
try:
import pynput # noqa
return False
except Exception:
print(
"Error trying to import pynput. Switching to headless mode. "
"As a result, the video stream from the cameras won't be shown, "
"and you won't be able to change the control flow with keyboards. "
"For more info, see traceback below.\n"
)
traceback.print_exc()
print()
return True
########################################################################################
@@ -192,29 +249,44 @@ def get_is_headless():
########################################################################################
def calibrate(robot: Robot):
if robot.calibration_path.exists():
print(f"Removing '{robot.calibration_path}'")
robot.calibration_path.unlink()
if robot.is_connected:
robot.disconnect()
# Calling `connect` automatically runs calibration
# when the calibration file is missing
robot.connect()
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
robot.connect()
start_time = time.perf_counter()
start_teleop_t = time.perf_counter()
while True:
now = time.perf_counter()
start_loop_t = time.perf_counter()
robot.teleop_step()
if fps is not None:
dt_s = time.perf_counter() - now
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
break
def record_dataset(
def record(
robot: Robot,
policy: torch.nn.Module | None = None,
hydra_cfg: DictConfig | None = None,
fps: int | None = None,
root="data",
repo_id="lerobot/debug",
@@ -225,10 +297,18 @@ def record_dataset(
video=True,
run_compute_stats=True,
push_to_hub=True,
tags=None,
num_image_writers=8,
force_override=False,
):
# TODO(rcadene): Add option to record logs
# TODO(rcadene): Clean this function via decomposition in higher level functions
_, dataset_name = repo_id.split("/")
if dataset_name.startswith("eval_") and policy is None:
raise ValueError(
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
)
if not video:
raise NotImplementedError()
@@ -255,32 +335,10 @@ def record_dataset(
else:
episode_index = 0
is_headless = get_is_headless()
# Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing.
timestamp = 0
start_time = time.perf_counter()
is_warmup_print = False
while timestamp < warmup_time_s:
if not is_warmup_print:
logging.info("Warming up (no data recording)")
os.system('say "Warmup" &')
is_warmup_print = True
now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True)
if not is_headless:
image_keys = [key for key in observation if "image" in key]
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
if is_headless():
logging.info(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
@@ -290,9 +348,7 @@ def record_dataset(
stop_recording = False
# Only import pynput if not in a headless environment
if is_headless:
logging.info("Headless environment detected. Keyboard input will not be available.")
else:
if not is_headless():
from pynput import keyboard
def on_press(key):
@@ -315,6 +371,53 @@ def record_dataset(
listener = keyboard.Listener(on_press=on_press)
listener.start()
# Load policy if any
if policy is not None:
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
# override fps using policy fps
fps = hydra_cfg.env.fps
# Execute a few seconds without recording data, to give times
# to the robot devices to connect and start synchronizing.
timestamp = 0
start_warmup_t = time.perf_counter()
is_warmup_print = False
while timestamp < warmup_time_s:
if not is_warmup_print:
logging.info("Warming up (no data recording)")
say("Warming up")
is_warmup_print = True
start_loop_t = time.perf_counter()
if policy is None:
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_warmup_t
# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
# Using only 4 worker threads to avoid blocking the main thread.
@@ -323,14 +426,18 @@ def record_dataset(
# Start recording all episodes
while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &')
say(f"Recording episode {episode_index}")
ep_dict = {}
frame_index = 0
timestamp = 0
start_time = time.perf_counter()
start_episode_t = time.perf_counter()
while timestamp < episode_time_s:
now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True)
start_loop_t = time.perf_counter()
if policy is None:
observation, action = robot.teleop_step(record_data=True)
else:
observation = robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
not_image_keys = [key for key in observation if "image" not in key]
@@ -342,11 +449,46 @@ def record_dataset(
)
]
if not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
for key in not_image_keys:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(observation[key])
if policy is not None:
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
# Order the robot to move
robot.send_action(action)
action = {"action": action}
for key in action:
if key not in ep_dict:
ep_dict[key] = []
@@ -354,14 +496,13 @@ def record_dataset(
frame_index += 1
dt_s = time.perf_counter() - now
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_time
timestamp = time.perf_counter() - start_episode_t
if exit_early:
exit_early = False
break
@@ -369,10 +510,10 @@ def record_dataset(
if not stop_recording:
# Start resetting env while the executor are finishing
logging.info("Reset the environment")
os.system('say "Reset the environment" &')
say("Reset the environment")
timestamp = 0
start_time = time.perf_counter()
start_vencod_t = time.perf_counter()
# During env reset we save the data and encode the videos
num_frames = frame_index
@@ -418,7 +559,7 @@ def record_dataset(
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
while timestamp < reset_time_s and not is_last_episode:
time.sleep(1)
timestamp = time.perf_counter() - start_time
timestamp = time.perf_counter() - start_vencod_t
pbar.update(1)
if exit_early:
exit_early = False
@@ -433,8 +574,8 @@ def record_dataset(
if is_last_episode:
logging.info("Done recording")
os.system('say "Done recording"')
if not is_headless:
say("Done recording", blocking=True)
if not is_headless():
listener.stop()
logging.info("Waiting for threads writing the images on disk to terminate...")
@@ -444,10 +585,14 @@ def record_dataset(
pass
break
robot.disconnect()
if not is_headless():
cv2.destroyAllWindows()
num_episodes = episode_index
logging.info("Encoding videos")
os.system('say "Encoding videos" &')
say("Encoding videos")
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
@@ -455,6 +600,7 @@ def record_dataset(
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
@@ -491,11 +637,12 @@ def record_dataset(
)
if run_compute_stats:
logging.info("Computing dataset statistics")
os.system('say "Computing dataset statistics" &')
say("Computing dataset statistics")
stats = compute_stats(lerobot_dataset)
lerobot_dataset.stats = stats
else:
logging.info("Skipping computation of the dataset statistrics")
stats = {}
logging.info("Skipping computation of the dataset statistics")
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
@@ -506,17 +653,17 @@ def record_dataset(
if push_to_hub:
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
logging.info("Exiting")
os.system('say "Exiting" &')
say("Exiting")
return lerobot_dataset
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
@@ -531,76 +678,20 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
robot.connect()
logging.info("Replaying episode")
os.system('say "Replaying episode"')
say("Replaying episode", blocking=True)
for idx in range(from_idx, to_idx):
now = time.perf_counter()
start_episode_t = time.perf_counter()
action = items[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - now
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
dt_s = time.perf_counter() - start_episode_t
log_control_info(robot, dt_s, fps=fps)
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
# TODO(rcadene): Add option to record eval dataset and logs
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
policy.eval()
policy.to(device)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
fps = hydra_cfg.env.fps
if not robot.is_connected:
robot.connect()
start_time = time.perf_counter()
while True:
now = time.perf_counter()
observation = robot.capture_observation()
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and hydra_cfg.use_amp
else nullcontext(),
):
# add batch dimension to 1
for name in observation:
observation[name] = observation[name].unsqueeze(0)
if device.type == "mps":
for name in observation:
observation[name] = observation[name].to(device)
action = policy.select_action(observation)
# remove batch dimension
action = action.squeeze(0)
robot.send_action(action.to("cpu"))
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - now
log_control_info(robot, dt_s, fps=fps)
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
break
if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True)
@@ -608,18 +699,26 @@ if __name__ == "__main__":
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument(
"--robot",
"--robot-path",
type=str,
default="koch",
help="Name of the robot provided to the `make_robot(name)` factory function.",
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
base_parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser_calib = subparsers.add_parser("calibrate", parents=[base_parser])
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
@@ -638,19 +737,19 @@ if __name__ == "__main__":
parser_record.add_argument(
"--warmup-time-s",
type=int,
default=2,
default=10,
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
)
parser_record.add_argument(
"--episode-time-s",
type=int,
default=10,
default=60,
help="Number of seconds for data recording for each episode.",
)
parser_record.add_argument(
"--reset-time-s",
type=int,
default=5,
default=60,
help="Number of seconds for resetting the environment after each episode.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
@@ -666,6 +765,12 @@ if __name__ == "__main__":
default=1,
help="Upload dataset to Hugging Face hub.",
)
parser_record.add_argument(
"--tags",
type=str,
nargs="*",
help="Add tags to your dataset on the hub.",
)
parser_record.add_argument(
"--num-image-writers",
type=int,
@@ -678,8 +783,23 @@ if __name__ == "__main__":
default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
)
parser_record.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser_record.add_argument(
"--policy-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
@@ -697,41 +817,46 @@ if __name__ == "__main__":
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser_policy.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
init_logging()
control_mode = args.mode
robot_name = args.robot
robot_path = args.robot_path
robot_overrides = args.robot_overrides
kwargs = vars(args)
del kwargs["mode"]
del kwargs["robot"]
del kwargs["robot_path"]
del kwargs["robot_overrides"]
robot = make_robot(robot_name)
if control_mode == "teleoperate":
robot_cfg = init_hydra_config(robot_path, robot_overrides)
robot = make_robot(robot_cfg)
if control_mode == "calibrate":
calibrate(robot, **kwargs)
elif control_mode == "teleoperate":
teleoperate(robot, **kwargs)
elif control_mode == "record_dataset":
record_dataset(robot, **kwargs)
elif control_mode == "replay_episode":
replay_episode(robot, **kwargs)
elif control_mode == "run_policy":
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", args.overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
run_policy(robot, policy, hydra_cfg)
elif control_mode == "record":
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
policy_overrides = args.policy_overrides
del kwargs["pretrained_policy_name_or_path"]
del kwargs["policy_overrides"]
policy_cfg = None
if pretrained_policy_name_or_path is not None:
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
record(robot, policy, policy_cfg, **kwargs)
else:
record(robot, **kwargs)
elif control_mode == "replay":
replay(robot, **kwargs)
if robot.is_connected:
# Disconnect manually to avoid a "Core dump" during process
# termination due to camera threads not properly exiting.
robot.disconnect()

View File

@@ -70,7 +70,13 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_hydra_config,
init_logging,
inside_slurm,
set_global_seed,
)
def rollout(
@@ -79,7 +85,6 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
enable_progbar: bool = False,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -109,7 +114,6 @@ def rollout(
are returned optionally because they typically take more memory to cache. Defaults to False.
render_callback: Optional rendering callback to be used after the environments are reset, and after
every step.
enable_progbar: Enable a progress bar over rollout steps.
Returns:
The dictionary described above.
"""
@@ -136,7 +140,7 @@ def rollout(
progbar = trange(
max_steps,
desc=f"Running rollout with at most {max_steps} steps",
disable=not enable_progbar,
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
leave=False,
)
while not np.all(done):
@@ -210,8 +214,6 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
enable_progbar: bool = False,
enable_inner_progbar: bool = False,
) -> dict:
"""
Args:
@@ -224,8 +226,6 @@ def eval_policy(
the "episodes" key of the returned dictionary.
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
seed is incremented by 1. If not provided, the environments are not manually seeded.
enable_progbar: Enable progress bar over batches.
enable_inner_progbar: Enable progress bar over steps in each batch.
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
@@ -266,7 +266,8 @@ def eval_policy(
if return_episode_data:
episode_data: dict | None = None
progbar = trange(n_batches, desc="Stepping through eval batches", disable=not enable_progbar)
# we dont want progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@@ -285,7 +286,6 @@ def eval_policy(
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
enable_progbar=enable_inner_progbar,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -454,6 +454,16 @@ def main(
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
if hydra_cfg.eval.batch_size > hydra_cfg.eval.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
f"({hydra_cfg.eval.batch_size} > {hydra_cfg.eval.n_episodes}). As a result, {hydra_cfg.eval.batch_size} "
f"eval environments will be instantiated, but only {hydra_cfg.eval.n_episodes} will be used. "
"This might significantly slow down evaluation. To fix this, you should update your command "
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={hydra_cfg.eval.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={hydra_cfg.eval.n_episodes}`)."
)
if out_dir is None:
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
@@ -487,8 +497,6 @@ def main(
max_episodes_rendered=10,
videos_dir=Path(out_dir) / "videos",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
)
print(info["aggregated"])

View File

@@ -56,7 +56,7 @@ from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
def get_from_raw_to_lerobot_format_fn(raw_format: str):
@@ -66,6 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif "openx_rlds" in raw_format:
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
elif raw_format == "dora_parquet":
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
elif raw_format == "xarm_pkl":
@@ -114,6 +116,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
)
def push_dataset_card_to_hub(
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
):
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
card = create_lerobot_dataset_card(tags=tags, text=text)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
"""Expect mp4 files to be all stored in a single "videos" directory.
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
@@ -189,9 +199,25 @@ def push_dataset_to_hub(
# convert dataset from original raw format to LeRobot format
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir, videos_dir, fps, video, episodes, encoding
)
fmt_kwgs = {
"raw_dir": raw_dir,
"videos_dir": videos_dir,
"fps": fps,
"video": video,
"episodes": episodes,
"encoding": encoding,
}
if "openx_rlds." in raw_format:
# Support for official OXE dataset name inside `raw_format`.
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
_, openx_dataset_name = raw_format.split(".")
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
@@ -213,10 +239,10 @@ def push_dataset_to_hub(
if push_to_hub:
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_dataset_card_to_hub(repo_id, revision="main")
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
api = HfApi()
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
if tests_data_dir:
# get the first episode
@@ -260,7 +286,7 @@ def main():
"--raw-format",
type=str,
required=True,
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).",
)
parser.add_argument(
"--repo-id",
@@ -320,6 +346,13 @@ def main():
default=0,
help="When set to 1, resumes a previous run.",
)
parser.add_argument(
"--cache-dir",
type=Path,
required=False,
default="/tmp",
help="Directory to store the temporary videos and images generated while creating the dataset.",
)
parser.add_argument(
"--tests-data-dir",
type=Path,

View File

@@ -51,59 +51,6 @@ from lerobot.common.utils.utils import (
from lerobot.scripts.eval import eval_policy
def make_optimizer_and_scheduler(cfg, policy):
if cfg.policy.name == "act":
optimizer_params_dicts = [
{
"params": [
p
for n, p in policy.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in policy.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": cfg.training.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
)
lr_scheduler = None
elif cfg.policy.name == "diffusion":
optimizer = torch.optim.Adam(
policy.diffusion.parameters(),
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
elif policy.name == "tdmpc":
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg)
else:
raise NotImplementedError()
return optimizer, lr_scheduler
def update_policy(
policy,
batch,
@@ -241,6 +188,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
raise NotImplementedError()
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
@@ -287,6 +235,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
)
if cfg.eval.batch_size > cfg.eval.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
"This might significantly slow down evaluation. To fix this, you should update your command "
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
)
# log metrics to terminal and wandb
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
@@ -323,7 +281,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
assert isinstance(policy, nn.Module)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
optimizer, lr_scheduler = policy.make_optimizer_and_scheduler(cfg)
grad_scaler = GradScaler(enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)

View File

@@ -108,8 +108,8 @@ def visualize_dataset(
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
root: Path | None = None,
output_dir: Path | None = None,
) -> Path | None:
if save:
assert (
@@ -209,6 +209,18 @@ def main():
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument(
"--batch-size",
type=int,
@@ -254,17 +266,6 @@ def main():
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
),
)
parser.add_argument(
"--output-dir",
type=str,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument(
"--root",
type=str,
help="Root directory for a dataset stored on a local machine.",
)
args = parser.parse_args()
visualize_dataset(**vars(args))

View File

@@ -0,0 +1,304 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesnt always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.
Example of usage:
- Visualize data stored on a local machine:
```bash
local$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht \
--episodes 7 3 5 1 4
```
"""
import argparse
import logging
import shutil
from pathlib import Path
import tqdm
from flask import Flask, redirect, render_template, url_for
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import init_logging
def run_server(
dataset: LeRobotDataset,
episodes: list[int],
host: str,
port: str,
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
def index():
# home page redirects to the first episode page
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
first_episode_id = episodes[0]
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=first_episode_id,
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id):
dataset_info = {
"repo_id": dataset.repo_id,
"num_samples": dataset.num_samples,
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
}
video_paths = get_episode_video_paths(dataset, episode_id)
language_instruction = get_episode_language_instruction(dataset, episode_id)
videos_info = [
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
for video_path in video_paths
]
if language_instruction:
videos_info[0]["language_instruction"] = language_instruction
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
ep_csv_url=ep_csv_url,
has_policy=False,
)
app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int):
ep_csv_fname = f"episode_{episode_id}.csv"
return ep_csv_fname
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0])
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = len(dataset.hf_dataset["action"][0])
header += [f"action_{i}" for i in range(dim_action)]
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if has_state:
row += data[i]["observation.state"].tolist()
if has_action:
row += data[i]["action"].tolist()
rows.append(row)
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n")
for row in rows:
row_str = [str(col) for col in row]
f.write(",".join(row_str) + "\n")
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.video_frame_keys
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.hf_dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def visualize_dataset_html(
repo_id: str,
root: Path | None = None,
episodes: list[int] = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
) -> Path | None:
init_logging()
dataset = LeRobotDataset(repo_id, root=root)
if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
# Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files.
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
template_dir = Path(__file__).resolve().parent.parent / "templates"
if episodes is None:
episodes = list(range(dataset.num_episodes))
logging.info("Writing CSV files")
for episode_index in tqdm.tqdm(episodes):
# write states and actions in a csv (it can be slow for big datasets)
ep_csv_fname = get_ep_csv_fname(episode_index)
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--episodes",
type=int,
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Web host used by the http server.",
)
parser.add_argument(
"--port",
type=int,
default=9090,
help="Web port used by the http server.",
)
parser.add_argument(
"--force-override",
type=int,
default=0,
help="Delete the output directory if it exists already.",
)
args = parser.parse_args()
visualize_dataset_html(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,375 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<!-- # TODO(rcadene, mishig25): store the js files locally -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/alpinejs/3.13.5/cdn.min.js" defer></script>
<script src="https://cdn.jsdelivr.net/npm/dygraphs@2.2.1/dist/dygraph.min.js" type="text/javascript"></script>
<script src="https://cdn.tailwindcss.com"></script>
<title>{{ dataset_info.repo_id }} episode {{ episode_id }}</title>
</head>
<!-- Use [Alpin.js](https://alpinejs.dev), a lightweight and easy to learn JS framework -->
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
<body class="flex h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
const { keyCode, key } = e;
if (keyCode === 32 || key === ' ') {
e.preventDefault();
$refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
}else if (key === 'ArrowDown' || key === 'ArrowUp'){
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
const lowestEpisodeId = {{ episodes }}.at(0);
const highestEpisodeId = {{ episodes }}.at(-1);
if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
window.location.href = `./episode_${nextEpisodeId}`;
}
}
}">
<!-- Sidebar -->
<div x-ref="sidebar" class="w-60 bg-slate-900 p-5 break-words max-h-screen overflow-y-auto">
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
<ul>
<li>
Number of samples/frames: {{ dataset_info.num_samples }}
</li>
<li>
Number of episodes: {{ dataset_info.num_episodes }}
</li>
<li>
Frames per second: {{ dataset_info.fps }}
</li>
</ul>
<p>Episodes:</p>
<ul class="ml-2">
{% for episode in episodes %}
<li class="font-mono text-sm mt-0.5">
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
Episode {{ episode }}
</a>
</li>
{% endfor %}
</ul>
</div>
<!-- Toggle sidebar button -->
<button class="flex items-center opacity-50 hover:opacity-100 mx-1"
@click="() => ($refs.sidebar.classList.toggle('hidden'))" title="Toggle sidebar">
<div class="bg-slate-500 w-2 h-10 rounded-full"></div>
</button>
<!-- Content -->
<div class="flex-1 max-h-screen flex flex-col gap-4 overflow-y-auto">
<h1 class="text-xl font-bold mt-4 font-mono">
Episode {{ episode_id }}
</h1>
<!-- Videos -->
<div class="flex flex-wrap gap-1">
{% for video_info in videos_info %}
<div class="max-w-96">
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<video muted loop type="video/mp4" class="min-w-64" @canplay="videoCanPlay" @timeupdate="() => {
if (video.duration) {
const time = video.currentTime;
const pc = (100 / video.duration) * time;
$refs.slider.value = pc;
dygraphTime = time;
dygraphIndex = Math.floor(pc * dygraph.numRows() / 100);
dygraph.setSelection(dygraphIndex, undefined, true, true);
$refs.timer.textContent = formatTime(time) + ' / ' + formatTime(video.duration);
updateTimeQuery(time.toFixed(2));
}
}" @ended="() => {
$refs.btnPlay.classList.remove('hidden');
$refs.btnPause.classList.add('hidden');
}"
@loadedmetadata="() => ($refs.timer.textContent = formatTime(0) + ' / ' + formatTime(video.duration))">
<source src="{{ video_info.url }}">
Your browser does not support the video tag.
</video>
</div>
{% endfor %}
</div>
<!-- Language instruction -->
{% if videos_info[0].language_instruction %}
<p class="font-medium mt-2">
Language Instruction: <span class="italic">{{ videos_info[0].language_instruction }}</span>
</p>
{% endif %}
<!-- Shortcuts info -->
<div class="text-sm hidden md:block">
Hotkeys: <span class="font-mono">Space</span> to pause/unpause, <span class="font-mono">Arrow Down</span> to go to next episode, <span class="font-mono">Arrow Up</span> to go to previous episode.
</div>
<!-- Controllers -->
<div class="flex gap-1 text-3xl items-center">
<button x-ref="btnPlay" class="-rotate-90" class="-rotate-90" title="Play. Toggle with Space" @click="() => {
videos.forEach(video => video.play());
$refs.btnPlay.classList.toggle('hidden');
$refs.btnPause.classList.toggle('hidden');
}">🔽</button>
<button x-ref="btnPause" class="hidden" title="Pause. Toggle with Space" @click="() => {
videos.forEach(video => video.pause());
$refs.btnPlay.classList.toggle('hidden');
$refs.btnPause.classList.toggle('hidden');
}">⏸️</button>
<button title="Jump backward 5 seconds"
@click="() => (videos.forEach(video => (video.currentTime -= 5)))"></button>
<button title="Jump forward 5 seconds"
@click="() => (videos.forEach(video => (video.currentTime += 5)))"></button>
<button title="Rewind from start"
@click="() => (videos.forEach(video => (video.currentTime = 0.0)))">↩️</button>
<input x-ref="slider" max="100" min="0" step="1" type="range" value="0" class="w-80 mx-2" @input="() => {
const sliderValue = $refs.slider.value;
videos.forEach(video => {
const time = (video.duration * sliderValue) / 100;
video.currentTime = time;
});
}" />
<div x-ref="timer" class="font-mono text-sm border border-slate-500 rounded-lg px-1 py-0.5 shrink-0">0:00 /
0:00
</div>
</div>
<!-- Graph -->
<div class="flex gap-2 mb-4 flex-wrap">
<div>
<div id="graph" @mouseleave="() => {
dygraph.setSelection(dygraphIndex, undefined, true, true);
dygraphTime = video.currentTime;
}">
</div>
<p x-ref="graphTimer" class="font-mono ml-14 mt-4"
x-init="$watch('dygraphTime', value => ($refs.graphTimer.innerText = `Time: ${dygraphTime.toFixed(2)}s`))">
Time: 0.00s
</p>
</div>
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
<thead>
<tr>
<th></th>
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
<th class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2">
<input type="checkbox" :checked="isColumnChecked(colIndex)"
@change="toggleColumn(colIndex)">
<p x-text="`${columnNames[colIndex]}`"></p>
</div>
</th>
</template>
</tr>
</thead>
<tbody>
<template x-for="(row, rowIndex) in rows">
<tr class="odd:bg-gray-800 even:bg-gray-900">
<td class="border border-slate-700">
<div class="flex gap-x-2 w-24 font-semibold px-1">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
<p x-text="`Motor ${rowIndex}`"></p>
</div>
</td>
<template x-for="(cell, colIndex) in row">
<td x-show="cell" class="border border-slate-700">
<div class="flex gap-x-2 w-24 justify-between px-2">
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
<span x-text="`${cell.value.toFixed(2)}`"
:style="`color: ${cell.color}`"></span>
</div>
</td>
</template>
</tr>
</template>
</tbody>
</table>
<div id="labels" class="hidden">
</div>
</div>
</div>
<script>
function createAlpineData() {
return {
// state
dygraph: null,
currentFrameData: null,
columnNames: ["state", "action", "pred action"],
nColumns: {% if has_policy %}3{% else %}2{% endif %},
checked: [],
dygraphTime: 0.0,
dygraphIndex: 0,
videos: null,
video: null,
colors: null,
nVideos: {{ videos_info | length }},
nVideoReadyToPlay: 0,
// alpine initialization
init() {
this.videos = document.querySelectorAll('video');
this.video = this.videos[0];
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
pixelsPerPoint: 0.01,
legend: 'always',
labelsDiv: document.getElementById('labels'),
labelsKMB: true,
strokeWidth: 1.5,
pointClickCallback: (event, point) => {
this.dygraphTime = point.xval;
this.updateTableValues(this.dygraphTime);
},
highlightCallback: (event, x, points, row, seriesName) => {
this.dygraphTime = x;
this.updateTableValues(this.dygraphTime);
},
drawCallback: (dygraph, is_initial) => {
if (is_initial) {
// dygraph initialization
this.dygraph.setSelection(this.dygraphIndex, undefined, true, true);
this.colors = this.dygraph.getColors();
this.checked = Array(this.colors.length).fill(true);
const seriesNames = this.dygraph.getLabels().slice(1);
const colors = [];
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
let lightnessIdx = 0;
const chunkSize = Math.ceil(seriesNames.length / this.nColumns);
for (let i = 0; i < seriesNames.length; i += chunkSize) {
const lightness = LIGHTNESS[lightnessIdx];
for (let hue = 0; hue < 360; hue += parseInt(360/chunkSize)) {
const color = `hsl(${hue}, 100%, ${lightness}%)`;
colors.push(color);
}
lightnessIdx += 1;
}
this.dygraph.updateOptions({ colors });
this.colors = colors;
this.updateTableValues();
let url = new URL(window.location.href);
let params = new URLSearchParams(url.search);
let time = params.get("t");
if(time){
time = parseFloat(time);
this.videos.forEach(video => (video.currentTime = time));
}
}
},
});
},
//#region Table Data
// turn dygraph's 1D data (at a given time t) to 2D data that whose columns names are defined in this.columnNames.
// 2d data view is used to create html table element.
get rows() {
if (!this.currentFrameData) {
return [];
}
const columnSize = Math.ceil(this.currentFrameData.length / this.nColumns);
return Array.from({
length: columnSize
}, (_, rowIndex) => {
const row = [
this.currentFrameData[rowIndex] || null,
this.currentFrameData[rowIndex + columnSize] || null,
];
if (this.nColumns === 3) {
row.push(this.currentFrameData[rowIndex + 2 * columnSize] || null)
}
return row;
});
},
isRowChecked(rowIndex) {
return this.rows[rowIndex].every(cell => cell && cell.checked);
},
isColumnChecked(colIndex) {
return this.rows.every(row => row[colIndex] && row[colIndex].checked);
},
toggleRow(rowIndex) {
const newState = !this.isRowChecked(rowIndex);
this.rows[rowIndex].forEach(cell => {
if (cell) cell.checked = newState;
});
this.updateTableValues();
},
toggleColumn(colIndex) {
const newState = !this.isColumnChecked(colIndex);
this.rows.forEach(row => {
if (row[colIndex]) row[colIndex].checked = newState;
});
this.updateTableValues();
},
// given time t, update the values in the html table with "data[t]"
updateTableValues(time) {
if (!this.colors) {
return;
}
let pc = (100 / this.video.duration) * (time === undefined ? this.video.currentTime : time);
if (isNaN(pc)) pc = 0;
const index = Math.floor(pc * this.dygraph.numRows() / 100);
// slice(1) to remove the timestamp point that we do not need
const labels = this.dygraph.getLabels().slice(1);
const values = this.dygraph.rawData_[index].slice(1);
const checkedNew = this.currentFrameData ? this.currentFrameData.map(cell => cell.checked) : Array(
this.colors.length).fill(true);
this.currentFrameData = labels.map((label, idx) => ({
label,
value: values[idx],
color: this.colors[idx],
checked: checkedNew[idx],
}));
const shouldUpdateVisibility = !this.checked.every((value, index) => value === checkedNew[index]);
if (shouldUpdateVisibility) {
this.checked = checkedNew;
this.dygraph.setVisibility(this.checked);
}
},
//#endregion
updateTimeQuery(time) {
let url = new URL(window.location.href);
let params = new URLSearchParams(url.search);
params.set("t", time);
url.search = params.toString();
window.history.replaceState({}, '', url.toString());
},
formatTime(time) {
var hours = Math.floor(time / 3600);
var minutes = Math.floor((time % 3600) / 60);
var seconds = Math.floor(time % 60);
return (hours > 0 ? hours + ':' : '') + (minutes < 10 ? '0' + minutes : minutes) + ':' + (seconds <
10 ?
'0' + seconds : seconds);
},
videoCanPlay() {
this.nVideoReadyToPlay += 1;
if(this.nVideoReadyToPlay == this.nVideos) {
// start autoplay all videos in sync
this.$refs.btnPlay.click();
}
}
};
}
</script>
</body>
</html>

Binary file not shown.

Before

Width:  |  Height:  |  Size: 416 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 446 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 326 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 312 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 469 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 318 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 420 KiB

BIN
media/koch/leader_rest.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 339 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 232 KiB

BIN
media/koch/leader_zero.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 484 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

170
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -192,6 +192,17 @@ charset-normalizer = ["charset-normalizer"]
html5lib = ["html5lib"]
lxml = ["lxml"]
[[package]]
name = "blinker"
version = "1.8.2"
description = "Fast, simple object-to-object and broadcast signaling"
optional = false
python-versions = ">=3.8"
files = [
{file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"},
{file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"},
]
[[package]]
name = "certifi"
version = "2024.7.4"
@@ -584,17 +595,6 @@ files = [
{file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"},
]
[[package]]
name = "decorator"
version = "4.4.2"
description = "Decorators for Humans"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*"
files = [
{file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
]
[[package]]
name = "deepdiff"
version = "7.0.1"
@@ -795,6 +795,7 @@ files = [
{file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"},
{file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"},
{file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"},
@@ -892,6 +893,28 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"]
typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "flask"
version = "3.0.3"
description = "A simple framework for building complex web applications."
optional = false
python-versions = ">=3.8"
files = [
{file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"},
{file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"},
]
[package.dependencies]
blinker = ">=1.6.2"
click = ">=8.1.3"
itsdangerous = ">=2.1.2"
Jinja2 = ">=3.1.2"
Werkzeug = ">=3.0.0"
[package.extras]
async = ["asgiref (>=3.2)"]
dotenv = ["python-dotenv"]
[[package]]
name = "frozenlist"
version = "1.4.1"
@@ -1350,6 +1373,7 @@ files = [
filelock = "*"
fsspec = ">=2023.5.0"
hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf-transfer\""}
InquirerPy = {version = "0.3.4", optional = true, markers = "extra == \"cli\""}
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
@@ -1536,6 +1560,24 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "inquirerpy"
version = "0.3.4"
description = "Python port of Inquirer.js (A collection of common interactive command-line user interfaces)"
optional = false
python-versions = ">=3.7,<4.0"
files = [
{file = "InquirerPy-0.3.4-py3-none-any.whl", hash = "sha256:c65fdfbac1fa00e3ee4fb10679f4d3ed7a012abf4833910e63c295827fe2a7d4"},
{file = "InquirerPy-0.3.4.tar.gz", hash = "sha256:89d2ada0111f337483cb41ae31073108b2ec1e618a49d7110b0d7ade89fc197e"},
]
[package.dependencies]
pfzy = ">=0.3.1,<0.4.0"
prompt-toolkit = ">=3.0.1,<4.0.0"
[package.extras]
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
[[package]]
name = "intel-openmp"
version = "2021.4.0"
@@ -1550,6 +1592,17 @@ files = [
{file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
]
[[package]]
name = "itsdangerous"
version = "2.2.0"
description = "Safely pass data to untrusted environments and back."
optional = false
python-versions = ">=3.8"
files = [
{file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"},
{file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"},
]
[[package]]
name = "jinja2"
version = "3.1.4"
@@ -1741,9 +1794,13 @@ files = [
{file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
{file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
{file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
{file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
@@ -1901,30 +1958,6 @@ files = [
intel-openmp = "==2021.*"
tbb = "==2021.*"
[[package]]
name = "moviepy"
version = "1.0.3"
description = "Video editing with Python"
optional = false
python-versions = "*"
files = [
{file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"},
]
[package.dependencies]
decorator = ">=4.0.2,<5.0"
imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""}
imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""}
numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""}
proglog = "<=1.0.0"
requests = ">=2.8.1,<3.0"
tqdm = ">=4.11.2,<5.0"
[package.extras]
doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"]
optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"]
test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"]
[[package]]
name = "mpmath"
version = "1.3.0"
@@ -2550,6 +2583,20 @@ other = ["pillow (>=8.0.1)"]
sisl = ["box2d-py (==2.3.5)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "scipy (>=1.4.1)"]
testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-markdown-docs", "pytest-xdist"]
[[package]]
name = "pfzy"
version = "0.3.4"
description = "Python port of the fzy fuzzy string matching algorithm"
optional = false
python-versions = ">=3.7,<4.0"
files = [
{file = "pfzy-0.3.4-py3-none-any.whl", hash = "sha256:5f50d5b2b3207fa72e7ec0ef08372ef652685470974a107d0d4999fc5a903a96"},
{file = "pfzy-0.3.4.tar.gz", hash = "sha256:717ea765dd10b63618e7298b2d98efd819e0b30cd5905c9707223dceeb94b3f1"},
]
[package.extras]
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
[[package]]
name = "pillow"
version = "10.4.0"
@@ -2697,18 +2744,18 @@ pyyaml = ">=5.1"
virtualenv = ">=20.10.0"
[[package]]
name = "proglog"
version = "0.1.10"
description = "Log and progress bar manager for console, notebooks, web..."
name = "prompt-toolkit"
version = "3.0.47"
description = "Library for building powerful interactive command lines in Python"
optional = false
python-versions = "*"
python-versions = ">=3.7.0"
files = [
{file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"},
{file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"},
{file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"},
{file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"},
]
[package.dependencies]
tqdm = "*"
wcwidth = "*"
[[package]]
name = "protobuf"
@@ -3276,6 +3323,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -3809,13 +3857,13 @@ test = ["pytest"]
[[package]]
name = "setuptools"
version = "71.0.1"
version = "71.0.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "setuptools-71.0.1-py3-none-any.whl", hash = "sha256:1eb8ef012efae7f6acbc53ec0abde4bc6746c43087fd215ee09e1df48998711f"},
{file = "setuptools-71.0.1.tar.gz", hash = "sha256:c51d7fd29843aa18dad362d4b4ecd917022131425438251f4e3d766c964dd1ad"},
{file = "setuptools-71.0.0-py3-none-any.whl", hash = "sha256:f06fbe978a91819d250a30e0dc4ca79df713d909e24438a42d0ec300fc52247f"},
{file = "setuptools-71.0.0.tar.gz", hash = "sha256:98da3b8aca443b9848a209ae4165e2edede62633219afa493a58fbba57f72e2e"},
]
[package.extras]
@@ -4215,6 +4263,34 @@ perf = ["orjson"]
sweeps = ["sweeps (>=0.2.0)"]
workspaces = ["wandb-workspaces"]
[[package]]
name = "wcwidth"
version = "0.2.13"
description = "Measures the displayed width of unicode strings in a terminal"
optional = false
python-versions = "*"
files = [
{file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"},
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
]
[[package]]
name = "werkzeug"
version = "3.0.3"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
files = [
{file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"},
{file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"},
]
[package.dependencies]
MarkupSafe = ">=2.1.1"
[package.extras]
watchdog = ["watchdog (>=2.3)"]
[[package]]
name = "xxhash"
version = "3.4.1"
@@ -4485,4 +4561,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb"
content-hash = "a340f2ed23db2f3c371c494cbc9a33392e122ed6713e6098277a87b3fb805f2b"

View File

@@ -43,7 +43,7 @@ opencv-python = ">=4.9.0"
diffusers = ">=0.27.2"
torchvision = ">=0.17.1"
h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer"], version = ">=0.23.0"}
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.23.0"}
gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1"
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
@@ -57,13 +57,15 @@ pytest-cov = {version = ">=5.0.0", optional = true}
datasets = ">=2.19.0"
imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1"
scikit-image = {version = ">=0.23.2", optional = true}
flask = ">=3.0.3"
pandas = {version = ">=2.2.2", optional = true}
scikit-image = {version = ">=0.23.2", optional = true}
dynamixel-sdk = {version = ">=3.7.31", optional = true}
pynput = {version = ">=1.7.7", optional = true}
# TODO(rcadene, salibert): 71.0.1 has a bug
setuptools = {version = "!=71.0.1", optional = true}

View File

@@ -15,7 +15,9 @@
# limitations under the License.
import pytest
from .utils import DEVICE
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEVICE, KOCH_ROBOT_CONFIG_PATH
def pytest_collection_finish():
@@ -27,11 +29,12 @@ def is_koch_available():
try:
from lerobot.common.robot_devices.robots.factory import make_robot
robot = make_robot("koch")
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
robot = make_robot(robot_cfg)
robot.connect()
del robot
return True
except Exception as e:
print("An alexander koch robot is not available.")
print("A koch robot is not available.")
print(e)
return False

View File

@@ -84,5 +84,7 @@ if __name__ == "__main__":
"lerobot/pusht",
"lerobot/aloha_sim_insertion_human",
"lerobot/xarm_lift_medium",
"lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch",
]:
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)

View File

@@ -22,7 +22,6 @@ from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.utils import DEFAULT_CONFIG_PATH
@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
dataloader = torch.utils.data.DataLoader(
dataset,

View File

@@ -3,13 +3,20 @@ from pathlib import Path
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.control_robot import record_dataset, replay_episode, run_policy, teleoperate
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_koch
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, KOCH_ROBOT_CONFIG_PATH, require_koch
def make_robot_(overrides=None):
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH, overrides)
robot = make_robot(robot_cfg)
return robot
@require_koch
# `require_koch` uses `request` to access `is_koch_available` fixture
def test_teleoperate(request):
robot = make_robot("koch")
robot = make_robot_()
teleoperate(robot, teleop_time_s=1)
teleoperate(robot, fps=30, teleop_time_s=1)
teleoperate(robot, fps=60, teleop_time_s=1)
@@ -17,20 +24,35 @@ def test_teleoperate(request):
@require_koch
def test_record_dataset_and_replay_episode_and_run_policy(tmpdir, request):
robot_name = "koch"
def test_calibrate(request):
robot = make_robot_()
calibrate(robot)
del robot
@require_koch
def test_record_without_cameras(tmpdir, request):
root = Path(tmpdir)
repo_id = "lerobot/debug"
robot = make_robot_(overrides=["~cameras"])
record(robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2)
@require_koch
def test_record_and_replay_and_policy(tmpdir, request):
env_name = "koch_real"
policy_name = "act_koch_real"
root = Path(tmpdir)
repo_id = "lerobot/debug"
robot = make_robot(robot_name)
dataset = record_dataset(
robot = make_robot_()
dataset = record(
robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2
)
replay_episode(robot, episode=0, fps=30, root=root, repo_id=repo_id)
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
@@ -43,6 +65,6 @@ def test_record_dataset_and_replay_episode_and_run_policy(tmpdir, request):
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
run_policy(robot, policy, cfg, run_time_s=1)
record(robot, policy, cfg, run_time_s=1)
del robot

View File

@@ -23,6 +23,7 @@ import einops
import pytest
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from safetensors.torch import load_file
import lerobot
@@ -34,6 +35,7 @@ from lerobot.common.datasets.compute_stats import (
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.utils import (
create_branch,
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
@@ -301,6 +303,9 @@ def test_flatten_unflatten_dict():
"lerobot/pusht",
"lerobot/aloha_sim_insertion_human",
"lerobot/xarm_lift_medium",
# (michel-aractingi) commenting the two datasets from openx as test is failing
# "lerobot/nyu_franka_play_dataset",
# "lerobot/cmu_stretch",
],
)
def test_backward_compatibility(repo_id):
@@ -316,6 +321,11 @@ def test_backward_compatibility(repo_id):
new_frame = dataset[i] # noqa: B023
old_frame = load_file(test_dir / f"frame_{i}.safetensors") # noqa: B023
# ignore language instructions (if exists) in language conditioned datasets
# TODO (michel-aractingi): transform language obs to langauge embeddings via tokenizer
new_frame.pop("language_instruction", None)
old_frame.pop("language_instruction", None)
new_keys = set(new_frame.keys())
old_keys = set(old_frame.keys())
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
@@ -385,3 +395,29 @@ def test_aggregate_stats():
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
@pytest.mark.skip("Requires internet access")
def test_create_branch():
api = HfApi()
repo_id = "cadene/test_create_branch"
repo_type = "dataset"
branch = "test"
ref = f"refs/heads/{branch}"
# Prepare a repo with a test branch
api.delete_repo(repo_id, repo_type=repo_type, missing_ok=True)
api.create_repo(repo_id, repo_type=repo_type)
create_branch(repo_id, repo_type=repo_type, branch=branch)
# Make sure the test branch exists
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
refs = [branch.ref for branch in branches]
assert ref in refs
# Overwrite it
create_branch(repo_id, repo_type=repo_type, branch=branch)
# Clean
api.delete_repo(repo_id, repo_type=repo_type)

View File

@@ -1,33 +1,54 @@
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): test calibration
# TODO(rcadene): add compatibility with other motors bus
import time
import hydra
import numpy as np
import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from tests.utils import require_koch
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import KOCH_ROBOT_CONFIG_PATH, require_koch
def make_motors_bus():
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
# Instantiating a common motors structure.
# Here the one from Alexander Koch follower arm.
motors_bus = hydra.utils.instantiate(robot_cfg.leader_arms.main)
return motors_bus
@require_koch
def test_find_port(request):
from lerobot.common.robot_devices.motors.dynamixel import find_port
find_port()
@require_koch
def test_configure_motors_all_ids_1(request):
# This test expect the configuration was already correct.
motors_bus = make_motors_bus()
motors_bus.connect()
motors_bus.write("Baud_Rate", [0] * len(motors_bus.motors))
motors_bus.set_bus_baudrate(9_600)
motors_bus.write("ID", [1] * len(motors_bus.motors))
del motors_bus
# Test configure
motors_bus = make_motors_bus()
motors_bus.connect()
assert motors_bus.are_motors_configured()
del motors_bus
@require_koch
def test_motors_bus(request):
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): test calibration
# TODO(rcadene): add compatibility with other motors bus
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
# Test instantiating a common motors structure.
# Here the one from Alexander Koch follower arm.
port = "/dev/tty.usbmodem575E0032081"
motors = {
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
}
motors_bus = DynamixelMotorsBus(port, motors)
motors_bus = make_motors_bus()
# Test reading and writting before connecting raises an error
with pytest.raises(RobotDeviceNotConnectedError):
@@ -41,7 +62,7 @@ def test_motors_bus(request):
del motors_bus
# Test connecting
motors_bus = DynamixelMotorsBus(port, motors)
motors_bus = make_motors_bus()
motors_bus.connect()
# Test connecting twice raises an error
@@ -52,7 +73,7 @@ def test_motors_bus(request):
motors_bus.write("Torque_Enable", 0)
values = motors_bus.read("Torque_Enable")
assert isinstance(values, np.ndarray)
assert len(values) == len(motors)
assert len(values) == len(motors_bus.motors)
assert (values == 0).all()
# Test writing torque on a specific motor
@@ -83,10 +104,3 @@ def test_motors_bus(request):
time.sleep(1)
new_values = motors_bus.read("Present_Position")
assert (new_values == values).all()
@require_koch
def test_find_port(request):
from lerobot.common.robot_devices.motors.dynamixel import find_port
find_port()

View File

@@ -37,7 +37,6 @@ from lerobot.common.policies.factory import (
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -214,7 +213,7 @@ def test_act_backbone_lr():
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
optimizer, _ = policy.make_optimizer_and_scheduler(cfg)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone

View File

@@ -1,5 +1,6 @@
import random
from typing import Callable
from uuid import uuid4
import numpy as np
import pytest
@@ -13,6 +14,7 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.utils.utils import (
get_global_random_state,
init_hydra_config,
seeded_context,
set_global_random_state,
set_global_seed,
@@ -83,3 +85,10 @@ def test_reset_episode_index():
correct_episode_index = [0, 0, 1, 2, 2, 2]
dataset = reset_episode_index(dataset)
assert dataset["episode_index"] == correct_episode_index
def test_init_hydra_config_empty():
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
with open(test_file, "w") as f:
f.write("\n")
init_hydra_config(test_file)

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
def test_visualize_dataset_html(tmpdir, repo_id):
tmpdir = Path(tmpdir)
visualize_dataset_html(
repo_id,
episodes=[0],
output_dir=tmpdir,
serve=False,
)
assert (tmpdir / "static" / "episode_0.csv").exists()

View File

@@ -23,6 +23,7 @@ from lerobot.common.utils.import_utils import is_package_available
# Pass this as the first argument to init_hydra_config.
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
KOCH_ROBOT_CONFIG_PATH = "lerobot/configs/robot/koch.yaml"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -161,6 +162,7 @@ def require_koch(func):
if request is None:
raise ValueError("The 'request' fixture must be passed to the test function as a parameter.")
# The function `is_koch_available` is defined in `tests/conftest.py`
if not request.getfixturevalue("is_koch_available"):
pytest.skip("An alexander koch robot is not available.")
return func(*args, **kwargs)