forked from tangger/lerobot
Merge branch 'main' into user/michel-aractingi/2024-11-27-port-hil-serl
This commit is contained in:
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -21,7 +21,7 @@ Provide a simple way for the reviewer to try out your changes.
|
|||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
|
pytest -sx tests/test_stuff.py::test_something
|
||||||
```
|
```
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/train.py --some.option=true
|
python lerobot/scripts/train.py --some.option=true
|
||||||
|
|||||||
8
.github/workflows/nightly-tests.yml
vendored
8
.github/workflows/nightly-tests.yml
vendored
@@ -7,10 +7,8 @@ on:
|
|||||||
schedule:
|
schedule:
|
||||||
- cron: "0 2 * * *"
|
- cron: "0 2 * * *"
|
||||||
|
|
||||||
env:
|
# env:
|
||||||
DATA_DIR: tests/data
|
|
||||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_all_tests_cpu:
|
run_all_tests_cpu:
|
||||||
name: CPU
|
name: CPU
|
||||||
@@ -30,13 +28,9 @@ jobs:
|
|||||||
working-directory: /lerobot
|
working-directory: /lerobot
|
||||||
steps:
|
steps:
|
||||||
- name: Tests
|
- name: Tests
|
||||||
env:
|
|
||||||
DATA_DIR: tests/data
|
|
||||||
run: pytest -v --cov=./lerobot --disable-warnings tests
|
run: pytest -v --cov=./lerobot --disable-warnings tests
|
||||||
|
|
||||||
- name: Tests end-to-end
|
- name: Tests end-to-end
|
||||||
env:
|
|
||||||
DATA_DIR: tests/data
|
|
||||||
run: make test-end-to-end
|
run: make test-end-to-end
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
63
.github/workflows/test.yml
vendored
63
.github/workflows/test.yml
vendored
@@ -29,7 +29,6 @@ jobs:
|
|||||||
name: Pytest
|
name: Pytest
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
DATA_DIR: tests/data
|
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -70,7 +69,6 @@ jobs:
|
|||||||
name: Pytest (minimal install)
|
name: Pytest (minimal install)
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
DATA_DIR: tests/data
|
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -104,39 +102,38 @@ jobs:
|
|||||||
&& rm -rf tests/outputs outputs
|
&& rm -rf tests/outputs outputs
|
||||||
|
|
||||||
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
|
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
|
||||||
end-to-end:
|
# end-to-end:
|
||||||
name: End-to-end
|
# name: End-to-end
|
||||||
runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
env:
|
# env:
|
||||||
DATA_DIR: tests/data
|
# MUJOCO_GL: egl
|
||||||
MUJOCO_GL: egl
|
# steps:
|
||||||
steps:
|
# - uses: actions/checkout@v4
|
||||||
- uses: actions/checkout@v4
|
# with:
|
||||||
with:
|
# lfs: true # Ensure LFS files are pulled
|
||||||
lfs: true # Ensure LFS files are pulled
|
|
||||||
|
|
||||||
- name: Install apt dependencies
|
# - name: Install apt dependencies
|
||||||
# portaudio19-dev is needed to install pyaudio
|
# # portaudio19-dev is needed to install pyaudio
|
||||||
run: |
|
# run: |
|
||||||
sudo apt-get update && \
|
# sudo apt-get update && \
|
||||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||||
|
|
||||||
- name: Install poetry
|
# - name: Install poetry
|
||||||
run: |
|
# run: |
|
||||||
pipx install poetry && poetry config virtualenvs.in-project true
|
# pipx install poetry && poetry config virtualenvs.in-project true
|
||||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
- name: Set up Python 3.10
|
# - name: Set up Python 3.10
|
||||||
uses: actions/setup-python@v5
|
# uses: actions/setup-python@v5
|
||||||
with:
|
# with:
|
||||||
python-version: "3.10"
|
# python-version: "3.10"
|
||||||
cache: "poetry"
|
# cache: "poetry"
|
||||||
|
|
||||||
- name: Install poetry dependencies
|
# - name: Install poetry dependencies
|
||||||
run: |
|
# run: |
|
||||||
poetry install --all-extras
|
# poetry install --all-extras
|
||||||
|
|
||||||
- name: Test end-to-end
|
# - name: Test end-to-end
|
||||||
run: |
|
# run: |
|
||||||
make test-end-to-end \
|
# make test-end-to-end \
|
||||||
&& rm -rf outputs
|
# && rm -rf outputs
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ We use `pytest` in order to run the tests. From the root of the
|
|||||||
repository, here's how to run tests with `pytest` for the library:
|
repository, here's how to run tests with `pytest` for the library:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR="tests/data" python -m pytest -sv ./tests
|
python -m pytest -sv ./tests
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
14
README.md
14
README.md
@@ -153,10 +153,12 @@ python lerobot/scripts/visualize_dataset.py \
|
|||||||
--episode-index 0
|
--episode-index 0
|
||||||
```
|
```
|
||||||
|
|
||||||
or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
or from a dataset in a local folder with the `root` option and the `--local-files-only` flag (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
|
python lerobot/scripts/visualize_dataset.py \
|
||||||
--repo-id lerobot/pusht \
|
--repo-id lerobot/pusht \
|
||||||
|
--root ./my_local_data_dir \
|
||||||
|
--local-files-only 1 \
|
||||||
--episode-index 0
|
--episode-index 0
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -208,12 +210,10 @@ dataset attributes:
|
|||||||
|
|
||||||
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
||||||
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
||||||
- videos are stored in mp4 format to save space or png files
|
- videos are stored in mp4 format to save space
|
||||||
- episode_data_index saved using `safetensor` tensor serialization format
|
- metadata are stored in plain json/jsonl files
|
||||||
- stats saved using `safetensor` tensor serialization format
|
|
||||||
- info are saved using JSON
|
|
||||||
|
|
||||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.
|
Datasets can be uploaded/downloaded from the Hugging Face hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||||
|
|
||||||
### Evaluate a pretrained policy
|
### Evaluate a pretrained policy
|
||||||
|
|
||||||
|
|||||||
@@ -192,7 +192,6 @@ Record 2 episodes and upload your dataset to the hub:
|
|||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--robot-path lerobot/configs/robot/so100.yaml \
|
--robot-path lerobot/configs/robot/so100.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/so100_test \
|
--repo-id ${HF_USER}/so100_test \
|
||||||
--tags so100 tutorial \
|
--tags so100 tutorial \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
@@ -212,7 +211,6 @@ echo ${HF_USER}/so100_test
|
|||||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/visualize_dataset_html.py \
|
python lerobot/scripts/visualize_dataset_html.py \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/so100_test
|
--repo-id ${HF_USER}/so100_test
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -220,10 +218,9 @@ python lerobot/scripts/visualize_dataset_html.py \
|
|||||||
|
|
||||||
Now try to replay the first episode on your robot:
|
Now try to replay the first episode on your robot:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
python lerobot/scripts/control_robot.py replay \
|
||||||
--robot-path lerobot/configs/robot/so100.yaml \
|
--robot-path lerobot/configs/robot/so100.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/so100_test \
|
--repo-id ${HF_USER}/so100_test \
|
||||||
--episode 0
|
--episode 0
|
||||||
```
|
```
|
||||||
@@ -232,7 +229,7 @@ DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
|||||||
|
|
||||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=data python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
dataset_repo_id=${HF_USER}/so100_test \
|
dataset_repo_id=${HF_USER}/so100_test \
|
||||||
policy=act_so100_real \
|
policy=act_so100_real \
|
||||||
env=so100_real \
|
env=so100_real \
|
||||||
@@ -248,7 +245,6 @@ Let's explain it:
|
|||||||
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
|
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you don't provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/huggingface`. In future versions of `lerobot`, both directories will be in sync.
|
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||||
|
|
||||||
@@ -259,7 +255,6 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
|
|||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--robot-path lerobot/configs/robot/so100.yaml \
|
--robot-path lerobot/configs/robot/so100.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/eval_act_so100_test \
|
--repo-id ${HF_USER}/eval_act_so100_test \
|
||||||
--tags so100 tutorial eval \
|
--tags so100 tutorial eval \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
|
|||||||
@@ -192,7 +192,6 @@ Record 2 episodes and upload your dataset to the hub:
|
|||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--robot-path lerobot/configs/robot/moss.yaml \
|
--robot-path lerobot/configs/robot/moss.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/moss_test \
|
--repo-id ${HF_USER}/moss_test \
|
||||||
--tags moss tutorial \
|
--tags moss tutorial \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
@@ -212,7 +211,6 @@ echo ${HF_USER}/moss_test
|
|||||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/visualize_dataset_html.py \
|
python lerobot/scripts/visualize_dataset_html.py \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/moss_test
|
--repo-id ${HF_USER}/moss_test
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -220,10 +218,9 @@ python lerobot/scripts/visualize_dataset_html.py \
|
|||||||
|
|
||||||
Now try to replay the first episode on your robot:
|
Now try to replay the first episode on your robot:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
python lerobot/scripts/control_robot.py replay \
|
||||||
--robot-path lerobot/configs/robot/moss.yaml \
|
--robot-path lerobot/configs/robot/moss.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/moss_test \
|
--repo-id ${HF_USER}/moss_test \
|
||||||
--episode 0
|
--episode 0
|
||||||
```
|
```
|
||||||
@@ -232,7 +229,7 @@ DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
|||||||
|
|
||||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=data python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
dataset_repo_id=${HF_USER}/moss_test \
|
dataset_repo_id=${HF_USER}/moss_test \
|
||||||
policy=act_moss_real \
|
policy=act_moss_real \
|
||||||
env=moss_real \
|
env=moss_real \
|
||||||
@@ -248,7 +245,6 @@ Let's explain it:
|
|||||||
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
|
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you don't provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/huggingface`. In future versions of `lerobot`, both directories will be in sync.
|
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||||
|
|
||||||
@@ -259,7 +255,6 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
|
|||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--robot-path lerobot/configs/robot/moss.yaml \
|
--robot-path lerobot/configs/robot/moss.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/eval_act_moss_test \
|
--repo-id ${HF_USER}/eval_act_moss_test \
|
||||||
--tags moss tutorial eval \
|
--tags moss tutorial eval \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ For a visual walkthrough of the assembly process, you can refer to [this video t
|
|||||||
|
|
||||||
## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1
|
## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1
|
||||||
|
|
||||||
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands.
|
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands (make sure gcc is installed).
|
||||||
|
|
||||||
Using `pip`:
|
Using `pip`:
|
||||||
```bash
|
```bash
|
||||||
@@ -778,7 +778,6 @@ Now run this to record 2 episodes:
|
|||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--robot-path lerobot/configs/robot/koch.yaml \
|
--robot-path lerobot/configs/robot/koch.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/koch_test \
|
--repo-id ${HF_USER}/koch_test \
|
||||||
--tags tutorial \
|
--tags tutorial \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
@@ -787,7 +786,7 @@ python lerobot/scripts/control_robot.py record \
|
|||||||
--num-episodes 2
|
--num-episodes 2
|
||||||
```
|
```
|
||||||
|
|
||||||
This will write your dataset locally to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `~/.cache/huggingface/lerobot/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||||
|
|
||||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
|
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
|
||||||
|
|
||||||
@@ -840,7 +839,6 @@ In the coming months, we plan to release a foundational model for robotics. We a
|
|||||||
You can visualize your dataset by running:
|
You can visualize your dataset by running:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/visualize_dataset_html.py \
|
python lerobot/scripts/visualize_dataset_html.py \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/koch_test
|
--repo-id ${HF_USER}/koch_test
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -858,7 +856,6 @@ To replay the first episode of the dataset you just recorded, run the following
|
|||||||
python lerobot/scripts/control_robot.py replay \
|
python lerobot/scripts/control_robot.py replay \
|
||||||
--robot-path lerobot/configs/robot/koch.yaml \
|
--robot-path lerobot/configs/robot/koch.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/koch_test \
|
--repo-id ${HF_USER}/koch_test \
|
||||||
--episode 0
|
--episode 0
|
||||||
```
|
```
|
||||||
@@ -871,7 +868,7 @@ Your robot should replicate movements similar to those you recorded. For example
|
|||||||
|
|
||||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=data python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
dataset_repo_id=${HF_USER}/koch_test \
|
dataset_repo_id=${HF_USER}/koch_test \
|
||||||
policy=act_koch_real \
|
policy=act_koch_real \
|
||||||
env=koch_real \
|
env=koch_real \
|
||||||
@@ -918,7 +915,6 @@ env:
|
|||||||
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
|
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you don't provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/huggingface`. In future versions of `lerobot`, both directories will be in sync.
|
|
||||||
|
|
||||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||||
|
|
||||||
@@ -991,7 +987,6 @@ To this end, you can use the `record` function from [`lerobot/scripts/control_ro
|
|||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--robot-path lerobot/configs/robot/koch.yaml \
|
--robot-path lerobot/configs/robot/koch.yaml \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/eval_koch_test \
|
--repo-id ${HF_USER}/eval_koch_test \
|
||||||
--tags tutorial eval \
|
--tags tutorial eval \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
@@ -1010,7 +1005,6 @@ As you can see, it's almost the same command as previously used to record your t
|
|||||||
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
|
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/visualize_dataset.py \
|
python lerobot/scripts/visualize_dataset.py \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/eval_koch_test
|
--repo-id ${HF_USER}/eval_koch_test
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,6 @@ Record one episode:
|
|||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--robot-path lerobot/configs/robot/stretch.yaml \
|
--robot-path lerobot/configs/robot/stretch.yaml \
|
||||||
--fps 20 \
|
--fps 20 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/stretch_test \
|
--repo-id ${HF_USER}/stretch_test \
|
||||||
--tags stretch tutorial \
|
--tags stretch tutorial \
|
||||||
--warmup-time-s 3 \
|
--warmup-time-s 3 \
|
||||||
@@ -146,7 +145,6 @@ Now try to replay this episode (make sure the robot's initial position is the sa
|
|||||||
python lerobot/scripts/control_robot.py replay \
|
python lerobot/scripts/control_robot.py replay \
|
||||||
--robot-path lerobot/configs/robot/stretch.yaml \
|
--robot-path lerobot/configs/robot/stretch.yaml \
|
||||||
--fps 20 \
|
--fps 20 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/stretch_test \
|
--repo-id ${HF_USER}/stretch_test \
|
||||||
--episode 0
|
--episode 0
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -84,7 +84,6 @@ python lerobot/scripts/control_robot.py record \
|
|||||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||||
--robot-overrides max_relative_target=null \
|
--robot-overrides max_relative_target=null \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/aloha_test \
|
--repo-id ${HF_USER}/aloha_test \
|
||||||
--tags aloha tutorial \
|
--tags aloha tutorial \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
@@ -104,7 +103,6 @@ echo ${HF_USER}/aloha_test
|
|||||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/visualize_dataset_html.py \
|
python lerobot/scripts/visualize_dataset_html.py \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/aloha_test
|
--repo-id ${HF_USER}/aloha_test
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -119,7 +117,6 @@ python lerobot/scripts/control_robot.py replay \
|
|||||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||||
--robot-overrides max_relative_target=null \
|
--robot-overrides max_relative_target=null \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/aloha_test \
|
--repo-id ${HF_USER}/aloha_test \
|
||||||
--episode 0
|
--episode 0
|
||||||
```
|
```
|
||||||
@@ -128,7 +125,7 @@ python lerobot/scripts/control_robot.py replay \
|
|||||||
|
|
||||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=data python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
dataset_repo_id=${HF_USER}/aloha_test \
|
dataset_repo_id=${HF_USER}/aloha_test \
|
||||||
policy=act_aloha_real \
|
policy=act_aloha_real \
|
||||||
env=aloha_real \
|
env=aloha_real \
|
||||||
@@ -144,7 +141,6 @@ Let's explain it:
|
|||||||
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
|
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
|
||||||
4. We provided `device=cuda` since we are training on a Nvidia GPU.
|
4. We provided `device=cuda` since we are training on a Nvidia GPU.
|
||||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you don't provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/huggingface`. In future versions of `lerobot`, both directories will be in sync.
|
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
|
||||||
|
|
||||||
@@ -156,7 +152,6 @@ python lerobot/scripts/control_robot.py record \
|
|||||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||||
--robot-overrides max_relative_target=null \
|
--robot-overrides max_relative_target=null \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id ${HF_USER}/eval_act_aloha_test \
|
--repo-id ${HF_USER}/eval_act_aloha_test \
|
||||||
--tags aloha tutorial eval \
|
--tags aloha tutorial eval \
|
||||||
--warmup-time-s 5 \
|
--warmup-time-s 5 \
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ def build_features(mode: str) -> dict:
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def load_raw_dataset(zarr_path: Path, load_images: bool = True):
|
def load_raw_dataset(zarr_path: Path):
|
||||||
try:
|
try:
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
||||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||||
|
|||||||
@@ -291,14 +291,22 @@ class LeRobotDatasetMetadata:
|
|||||||
obj.root.mkdir(parents=True, exist_ok=False)
|
obj.root.mkdir(parents=True, exist_ok=False)
|
||||||
|
|
||||||
if robot is not None:
|
if robot is not None:
|
||||||
|
<<<<<<< HEAD
|
||||||
features = {**(features or {}), **get_features_from_robot(robot)}
|
features = {**(features or {}), **get_features_from_robot(robot)}
|
||||||
|
=======
|
||||||
|
features = get_features_from_robot(robot, use_videos)
|
||||||
|
>>>>>>> main
|
||||||
robot_type = robot.robot_type
|
robot_type = robot.robot_type
|
||||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
||||||
)
|
)
|
||||||
|
<<<<<<< HEAD
|
||||||
elif robot_type is None or features is None:
|
elif robot_type is None or features is None:
|
||||||
|
=======
|
||||||
|
elif features is None:
|
||||||
|
>>>>>>> main
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,639 +0,0 @@
|
|||||||
OPENX_DATASET_CONFIGS:
|
|
||||||
fractal20220817_data:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- base_pose_tool_reached
|
|
||||||
- gripper_closed
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
kuka:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- clip_function_input/base_pose_tool_reached
|
|
||||||
- gripper_closed
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
bridge_openx:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- EEF_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
taco_play:
|
|
||||||
image_obs_keys:
|
|
||||||
- rgb_static
|
|
||||||
- rgb_gripper
|
|
||||||
depth_obs_keys:
|
|
||||||
- depth_static
|
|
||||||
- depth_gripper
|
|
||||||
state_obs_keys:
|
|
||||||
- state_eef
|
|
||||||
- state_gripper
|
|
||||||
fps: 15
|
|
||||||
|
|
||||||
jaco_play:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- image_wrist
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state_eef
|
|
||||||
- state_gripper
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
berkeley_cable_routing:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- top_image
|
|
||||||
- wrist45_image
|
|
||||||
- wrist225_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- robot_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
roboturk:
|
|
||||||
image_obs_keys:
|
|
||||||
- front_rgb
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- null
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
nyu_door_opening_surprising_effectiveness:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- null
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
viola:
|
|
||||||
image_obs_keys:
|
|
||||||
- agentview_rgb
|
|
||||||
- eye_in_hand_rgb
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- joint_states
|
|
||||||
- gripper_states
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
berkeley_autolab_ur5:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- hand_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- image_with_depth
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
toto:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 30
|
|
||||||
|
|
||||||
language_table:
|
|
||||||
image_obs_keys:
|
|
||||||
- rgb
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- effector_translation
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
columbia_cairlab_pusht_real:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- robot_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
stanford_kuka_multimodal_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- depth_image
|
|
||||||
state_obs_keys:
|
|
||||||
- ee_position
|
|
||||||
- ee_orientation
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
nyu_rot_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
io_ai_tech:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- image_fisheye
|
|
||||||
- image_left_side
|
|
||||||
- image_right_side
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
stanford_hydra_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
austin_buds_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
nyu_franka_play_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- image_additional_view
|
|
||||||
depth_obs_keys:
|
|
||||||
- depth
|
|
||||||
- depth_additional_view
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
maniskill_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- depth
|
|
||||||
- wrist_depth
|
|
||||||
state_obs_keys:
|
|
||||||
- tcp_pose
|
|
||||||
- gripper_state
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
furniture_bench_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
cmu_franka_exploration_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- highres_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- null
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
ucsd_kitchen_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- joint_state
|
|
||||||
fps: 2
|
|
||||||
|
|
||||||
ucsd_pick_and_place_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
spoc:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- image_manipulation
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- null
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
austin_sailor_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
austin_sirius_dataset_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
bc_z:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- present/xyz
|
|
||||||
- present/axis_angle
|
|
||||||
- present/sensed_close
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
utokyo_pr2_opening_fridge_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
utokyo_xarm_pick_and_place_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- image2
|
|
||||||
- hand_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- end_effector_pose
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
utokyo_xarm_bimanual_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- pose_r
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
robo_net:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- image1
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 1
|
|
||||||
|
|
||||||
robo_set:
|
|
||||||
image_obs_keys:
|
|
||||||
- image_left
|
|
||||||
- image_right
|
|
||||||
- image_wrist
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
- state_velocity
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
berkeley_mvp_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- hand_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- gripper
|
|
||||||
- pose
|
|
||||||
- joint_pos
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
berkeley_rpt_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- hand_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- joint_pos
|
|
||||||
- gripper
|
|
||||||
fps: 30
|
|
||||||
|
|
||||||
kaist_nonprehensile_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
stanford_mask_vit_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
|
|
||||||
tokyo_u_lsmo_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
dlr_sara_pour_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
dlr_sara_grid_clamp_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
dlr_edan_shared_control_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
asu_table_top_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 12.5
|
|
||||||
|
|
||||||
stanford_robocook_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image_1
|
|
||||||
- image_2
|
|
||||||
depth_obs_keys:
|
|
||||||
- depth_1
|
|
||||||
- depth_2
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
imperialcollege_sawyer_wrist_cam:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
iamlab_cmu_pickup_insert_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- joint_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
uiuc_d3field:
|
|
||||||
image_obs_keys:
|
|
||||||
- image_1
|
|
||||||
- image_2
|
|
||||||
depth_obs_keys:
|
|
||||||
- depth_1
|
|
||||||
- depth_2
|
|
||||||
state_obs_keys:
|
|
||||||
- null
|
|
||||||
fps: 1
|
|
||||||
|
|
||||||
utaustin_mutex:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
berkeley_fanuc_manipulation:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- joint_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
cmu_playing_with_food:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- finger_vision_1
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
cmu_play_fusion:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
cmu_stretch:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- eef_state
|
|
||||||
- gripper_state
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
berkeley_gnm_recon:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
- position
|
|
||||||
- yaw
|
|
||||||
fps: 3
|
|
||||||
|
|
||||||
berkeley_gnm_cory_hall:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
- position
|
|
||||||
- yaw
|
|
||||||
fps: 5
|
|
||||||
|
|
||||||
berkeley_gnm_sac_son:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
- position
|
|
||||||
- yaw
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
droid:
|
|
||||||
image_obs_keys:
|
|
||||||
- exterior_image_1_left
|
|
||||||
- exterior_image_2_left
|
|
||||||
- wrist_image_left
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- proprio
|
|
||||||
fps: 15
|
|
||||||
|
|
||||||
droid_100:
|
|
||||||
image_obs_keys:
|
|
||||||
- exterior_image_1_left
|
|
||||||
- exterior_image_2_left
|
|
||||||
- wrist_image_left
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- proprio
|
|
||||||
fps: 15
|
|
||||||
|
|
||||||
fmb:
|
|
||||||
image_obs_keys:
|
|
||||||
- image_side_1
|
|
||||||
- image_side_2
|
|
||||||
- image_wrist_1
|
|
||||||
- image_wrist_2
|
|
||||||
depth_obs_keys:
|
|
||||||
- image_side_1_depth
|
|
||||||
- image_side_2_depth
|
|
||||||
- image_wrist_1_depth
|
|
||||||
- image_wrist_2_depth
|
|
||||||
state_obs_keys:
|
|
||||||
- proprio
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
dobbe:
|
|
||||||
image_obs_keys:
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- proprio
|
|
||||||
fps: 3.75
|
|
||||||
|
|
||||||
usc_cloth_sim_converted_externally_to_rlds:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- null
|
|
||||||
fps: 10
|
|
||||||
|
|
||||||
plex_robosuite:
|
|
||||||
image_obs_keys:
|
|
||||||
- image
|
|
||||||
- wrist_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 20
|
|
||||||
|
|
||||||
conq_hose_manipulation:
|
|
||||||
image_obs_keys:
|
|
||||||
- frontleft_fisheye_image
|
|
||||||
- frontright_fisheye_image
|
|
||||||
- hand_color_image
|
|
||||||
depth_obs_keys:
|
|
||||||
- null
|
|
||||||
state_obs_keys:
|
|
||||||
- state
|
|
||||||
fps: 30
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
NOTE(YL): Adapted from:
|
|
||||||
Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py
|
|
||||||
|
|
||||||
data_utils.py
|
|
||||||
|
|
||||||
Additional utils for data processing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
|
|
||||||
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Snap a continuous gripper-action sequence to binary values (1.0 = open, 0.0 = closed).

    Entries above 0.95 count as open and entries below 0.05 count as closed. Anything in between is
    treated as a transition and is relabeled with the open/closed state that is reached *after* it in
    the sequence.

    If the sequence ends on in-between entries there is no later state to copy, so that trailing run
    keeps the value of the last action in the trajectory.

    The reverse `tf.scan` below implements the following logic:
        new_actions = np.empty_like(actions)
        carry = actions[-1]
        for i in reversed(range(actions.shape[0])):
            if not in_between_mask[i]:
                carry = float(open_mask[i])
            new_actions[i] = carry
    """
    is_open = actions > 0.95
    is_closed = actions < 0.05
    is_transition = tf.logical_not(tf.logical_or(is_open, is_closed))
    open_as_float = tf.cast(is_open, tf.float32)

    def _step(prev, idx):
        # Transition entries inherit the value carried over from later timesteps;
        # open/closed entries reset the carry to their own binary value.
        return tf.cond(
            is_transition[idx],
            lambda: tf.cast(prev, tf.float32),
            lambda: open_as_float[idx],
        )

    indices = tf.range(tf.shape(actions)[0])
    return tf.scan(_step, indices, actions[-1], reverse=True)
|
|
||||||
|
|
||||||
|
|
||||||
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """Return `1 - actions`, i.e. flip the gripper values around 0.5."""
    inverted = 1 - actions
    return inverted
|
|
||||||
|
|
||||||
|
|
||||||
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Turn relative gripper commands into an absolute gripper state per timestep.

    Input: values above 0.1 are read as "close", values below -0.1 as "open", and everything else as
    "no change". Output: 0.0 = closed, 1.0 = open.

    The state before the first non-zero command is taken to be the opposite of that command, so the
    first relative command is assumed not to be redundant (e.g. "close" when already closed).
    """
    # Internal encoding =>> 1 for opening, -1 for closing, 0 for no change
    wants_open = actions < -0.1
    wants_close = actions > 0.1
    commands = tf.where(wants_open, 1, tf.where(wants_close, -1, 0))

    def _step(state, idx):
        # A zero command keeps the previous state; any other command becomes the new state.
        return tf.cond(commands[idx] == 0, lambda: state, lambda: commands[idx])

    # Start from the opposite of the first non-zero command. If there is no such command the
    # lookup yields 0, and the trajectory is assumed to be open (1) throughout.
    first_command_idx = tf.argmax(commands != 0, axis=0)
    initial_state = -1 * commands[first_command_idx]
    initial_state = tf.cond(initial_state == 0, lambda: 1, lambda: initial_state)

    # States are -1 (closed) or 1 (open); rescale them to 0.0 / 1.0.
    states = tf.scan(_step, tf.range(tf.shape(actions)[0]), initial_state)
    return tf.cast(states, tf.float32) / 2 + 0.5
|
|
||||||
|
|
||||||
|
|
||||||
# === Bridge-V2 =>> Dataset-Specific Transform ===
|
|
||||||
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
    """
    Replace the first six action dims with the difference between consecutive proprioceptive states.

    Every entry of `traj` loses its last timestep (it has no following state to take a difference
    with). The last action column is kept from the original actions.
    """
    state = traj["observation"]["state"]
    state_deltas = state[1:, :6] - state[:-1, :6]

    relabeled = tf.nest.map_structure(lambda x: x[:-1], traj)
    last_action_column = traj["action"][:-1, -1:]
    relabeled["action"] = tf.concat([state_deltas, last_action_column], axis=1)
    return relabeled
|
|
||||||
|
|
||||||
|
|
||||||
# === RLDS Dataset Initialization Utilities ===
|
|
||||||
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
    """
    Print a boxed summary with one row per dataset: its `name` and its sampling weight.

    Datasets and weights are paired positionally; if the two lists differ in length, the extra
    entries of the longer one are not printed (`strict=False`).
    """
    print("\n######################################################################################")
    print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")

    for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
        name = dataset_kwargs["name"]
        # Name plus weight field always span 80 characters, so the right border lines up.
        pad = 80 - len(name)
        print(f"# {name}: {weight:=>{pad}f} #")

    print("######################################################################################\n")
|
|
||||||
@@ -1,200 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
NOTE(YL): Adapted from:
|
|
||||||
OpenVLA: https://github.com/openvla/openvla
|
|
||||||
|
|
||||||
Episode transforms for DROID dataset.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import tensorflow_graphics.geometry.transformation as tfg
|
|
||||||
|
|
||||||
|
|
||||||
def rmat_to_euler(rot_mat):
    """Convert rotation matrices to Euler angles via `tfg.euler.from_rotation_matrix`."""
    euler_angles = tfg.euler.from_rotation_matrix(rot_mat)
    return euler_angles
|
|
||||||
|
|
||||||
|
|
||||||
def euler_to_rmat(euler):
    """Convert Euler angles to rotation matrices via `tfg.rotation_matrix_3d.from_euler`."""
    rot_mat = tfg.rotation_matrix_3d.from_euler(euler)
    return rot_mat
|
|
||||||
|
|
||||||
|
|
||||||
def invert_rmat(rot_mat):
    """Invert rotation matrices via `tfg.rotation_matrix_3d.inverse`."""
    inverse = tfg.rotation_matrix_3d.inverse(rot_mat)
    return inverse
|
|
||||||
|
|
||||||
|
|
||||||
def rotmat_to_rot6d(mat):
    """
    Flatten the first two rows of a rotation matrix into the R6 rotation representation.

    Args:
        mat: rotation matrix; the last two axes are indexed as (row, column), leading axes are kept.

    Returns:
        6d vector: the first row followed by the second row, concatenated along the last axis.
    """
    top_two_rows = mat[..., :2, :]
    first_row = top_two_rows[..., 0, :]
    second_row = top_two_rows[..., 1, :]
    return tf.concat([first_row, second_row], axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
    """
    Re-express velocity actions (translation + rotation) given in the robot base frame in the wrist frame.

    Args:
        velocity: 6d velocity action (3 x translation, 3 x rotation)
        wrist_in_robot_frame: 6d pose of the end-effector in robot base frame

    Returns:
        9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
    """
    wrist_rot = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
    wrist_rot_inv = invert_rmat(wrist_rot)

    # Translation: dT_wrist = R^-1 dT_base. A trailing axis is added for the matmul and dropped after.
    translation_column = velocity[:, :3][..., None]
    translation_in_wrist = (wrist_rot_inv @ translation_column)[..., 0]

    # Rotation: dR_wrist = R^-1 dR_base R
    rotation_delta = euler_to_rmat(velocity[:, 3:6])
    rotation_in_wrist = wrist_rot_inv @ (rotation_delta @ wrist_rot)

    return tf.concat([translation_in_wrist, rotmat_to_rot6d(rotation_in_wrist)], axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def rand_swap_exterior_images(img1, img2):
    """
    Return the two exterior images either in the given order or swapped.

    The choice is made by a single uniform random draw: above 0.5 keeps the order, otherwise the
    images are swapped (used for training with a single exterior input).
    """
    keep_order = tf.random.uniform(shape=[]) > 0.5
    return tf.cond(keep_order, lambda: (img1, img2), lambda: (img2, img1))
|
|
||||||
|
|
||||||
|
|
||||||
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in the *base* frame of the robot.

    Writes three things into `trajectory` (in place) and returns it:
      - "action": cartesian velocity (translation, rotation) followed by `1 - gripper_position`
      - the two exterior images, randomly kept in order or swapped
      - observation "proprio": cartesian position followed by gripper position
    """
    action_dict = trajectory["action_dict"]
    observation = trajectory["observation"]

    cartesian_velocity = action_dict["cartesian_velocity"]
    translation = cartesian_velocity[:, :3]
    rotation = cartesian_velocity[:, 3:6]
    inverted_gripper = 1 - action_dict["gripper_position"]
    trajectory["action"] = tf.concat((translation, rotation, inverted_gripper), axis=-1)

    swapped = rand_swap_exterior_images(
        observation["exterior_image_1_left"],
        observation["exterior_image_2_left"],
    )
    observation["exterior_image_1_left"], observation["exterior_image_2_left"] = swapped

    observation["proprio"] = tf.concat(
        (observation["cartesian_position"], observation["gripper_position"]),
        axis=-1,
    )
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in the *wrist* frame of the robot.

    Writes three things into `trajectory` (in place) and returns it:
      - "action": the cartesian velocity converted to the wrist frame, followed by the gripper position
      - the two exterior images, randomly kept in order or swapped
      - observation "proprio": cartesian position followed by gripper position
    """
    action_dict = trajectory["action_dict"]
    observation = trajectory["observation"]

    wrist_velocity = velocity_act_to_wrist_frame(
        action_dict["cartesian_velocity"], observation["cartesian_position"]
    )
    # NOTE(review): unlike the base-frame DROID transforms, the gripper position is appended here
    # without the `1 - ...` inversion — confirm this is intended.
    trajectory["action"] = tf.concat((wrist_velocity, action_dict["gripper_position"]), axis=-1)

    swapped = rand_swap_exterior_images(
        observation["exterior_image_1_left"],
        observation["exterior_image_2_left"],
    )
    observation["exterior_image_1_left"], observation["exterior_image_2_left"] = swapped

    observation["proprio"] = tf.concat(
        (observation["cartesian_position"], observation["gripper_position"]),
        axis=-1,
    )
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in the *base* frame of the robot.

    Writes "action" (cartesian velocity followed by `1 - gripper_position`) and observation "proprio"
    (cartesian position followed by gripper position) into `trajectory` in place and returns it.
    The exterior images are left untouched.
    """
    action_dict = trajectory["action_dict"]
    observation = trajectory["observation"]

    cartesian_velocity = action_dict["cartesian_velocity"]
    translation = cartesian_velocity[:, :3]
    rotation = cartesian_velocity[:, 3:6]
    inverted_gripper = 1 - action_dict["gripper_position"]
    trajectory["action"] = tf.concat((translation, rotation, inverted_gripper), axis=-1)

    observation["proprio"] = tf.concat(
        (observation["cartesian_position"], observation["gripper_position"]),
        axis=-1,
    )
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def zero_action_filter(traj: Dict) -> bool:
    """
    Tell whether a trajectory contains any non-zero relative action (the gripper dim is ignored).

    The filter is applied *after* action normalization, so the first six action dims are compared
    against the normalized value of an all-zero action rather than against 0 itself.

    Returns:
        True if any of the first six action dims differs from "normalized 0" by more than 1e-5.
    """
    # Per-dimension bounds used to map a raw zero action into normalized space.
    droid_q01 = tf.convert_to_tensor(
        [
            -0.7776297926902771,
            -0.5803514122962952,
            -0.5795090794563293,
            -0.6464047729969025,
            -0.7041108310222626,
            -0.8895104378461838,
        ]
    )
    droid_q99 = tf.convert_to_tensor(
        [
            0.7597932070493698,
            0.5726242214441299,
            0.7351000607013702,
            0.6705610305070877,
            0.6464948207139969,
            0.8897542208433151,
        ]
    )

    relative_actions = traj["action"][:, :6]
    # Same affine map as the normalization, applied to zeros; 1e-8 avoids division by zero.
    normalized_zero = 2 * (tf.zeros_like(relative_actions) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1

    deviation = tf.math.abs(relative_actions - normalized_zero)
    return tf.reduce_any(deviation > 1e-5)
|
|
||||||
@@ -1,859 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
NOTE(YL): Adapted from:
|
|
||||||
OpenVLA: https://github.com/openvla/openvla
|
|
||||||
Octo: https://github.com/octo-models/octo
|
|
||||||
|
|
||||||
transforms.py
|
|
||||||
|
|
||||||
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
|
|
||||||
|
|
||||||
Transforms adopt the following structure:
|
|
||||||
Input: Dictionary of *batched* features (i.e., has leading time dimension)
|
|
||||||
Output: Dictionary `step` =>> {
|
|
||||||
"observation": {
|
|
||||||
<image_keys, depth_image_keys>
|
|
||||||
State (in chosen state representation)
|
|
||||||
},
|
|
||||||
"action": Action (in chosen action representation),
|
|
||||||
"language_instruction": str
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import (
|
|
||||||
binarize_gripper_actions,
|
|
||||||
invert_gripper_actions,
|
|
||||||
rel2abs_gripper_actions,
|
|
||||||
relabel_bridge_actions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def droid_baseact_transform_fn():
    """Return `droid_baseact_transform`, importing `droid_utils` only when this function is called."""
    # Function-scope import: `droid_utils` is not loaded when this module is imported.
    from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform

    transform = droid_baseact_transform
    return transform
|
|
||||||
|
|
||||||
|
|
||||||
def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    Applies to the version of Bridge V2 in the Open X-Embodiment mixture.

    Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!

    After dropping that timestep, the action dict is flattened into one tensor (world vector, rotation
    delta, open-gripper flag), the language instruction is lifted to the top level, actions are
    relabeled via `relabel_bridge_actions`, and the state is split into `EEF_state` / `gripper_state`.
    """
    # Drop the first timestep everywhere except in the trajectory metadata.
    for key, value in trajectory.items():
        if key == "traj_metadata":
            continue
        if key in ("observation", "action"):
            # Nested dicts: slice every entry in place.
            for sub_key in value:
                value[sub_key] = value[sub_key][1:]
        else:
            trajectory[key] = value[1:]

    raw_action = trajectory["action"]
    open_gripper = tf.cast(raw_action["open_gripper"][:, None], tf.float32)
    trajectory["action"] = tf.concat(
        (raw_action["world_vector"], raw_action["rotation_delta"], open_gripper),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]

    trajectory = relabel_bridge_actions(trajectory)

    state = trajectory["observation"]["state"]
    trajectory["observation"]["EEF_state"] = state[:, :6]
    trajectory["observation"]["gripper_state"] = state[:, -1:]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    Applies to original version of Bridge V2 from the official project website.

    Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!

    The last action column is passed through `binarize_gripper_actions`, the
    trajectory is relabeled with `relabel_bridge_actions`, and `EEF_state` /
    `gripper_state` are added under `observation`.
    """
    for name in trajectory:
        if name == "traj_metadata":
            # metadata is left untouched
            continue
        if name == "observation":
            # observation is a dict: drop the first timestep of every field
            observation = trajectory[name]
            for field in observation:
                observation[field] = observation[field][1:]
        else:
            trajectory[name] = trajectory[name][1:]

    action = trajectory["action"]
    arm_action = action[:, :6]
    gripper_action = binarize_gripper_actions(action[:, -1])[:, None]
    trajectory["action"] = tf.concat([arm_action, gripper_action], axis=1)

    trajectory = relabel_bridge_actions(trajectory)

    observation = trajectory["observation"]
    observation["EEF_state"] = observation["state"][:, :6]
    observation["gripper_state"] = observation["state"][:, -1:]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Binarize the gripper column of the action and derive EEF / gripper state.

    `EEF_state` is the first six columns of `cartesian_position`;
    `gripper_state` is the last column of `gripper_position`.
    """
    action = trajectory["action"]
    arm_action = action[:, :6]
    gripper_action = binarize_gripper_actions(action[:, -1])[:, None]
    trajectory["action"] = tf.concat([arm_action, gripper_action], axis=1)

    observation = trajectory["observation"]
    observation["EEF_state"] = observation["cartesian_position"][:, :6]
    observation["gripper_state"] = observation["gripper_position"][:, -1:]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict into one tensor and expose the language instruction.

    The gripper channel comes from `rel2abs_gripper_actions` applied to the
    first column of `gripper_closedness_action`.
    """
    raw_action = trajectory["action"]

    # make gripper action absolute action, +1 = open, 0 = close
    absolute_gripper = rel2abs_gripper_actions(raw_action["gripper_closedness_action"][:, 0])

    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        absolute_gripper[:, None],
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict and decode the ZLIB-compressed observation fields.

    `clip_function_input/base_pose_tool_reached` is decoded to float32 and
    reshaped to (-1, 7); `gripper_closed` is decoded to float32 and reshaped
    to (-1, 1). The language instruction is copied to the top level.
    """
    raw_action = trajectory["action"]

    # make gripper action absolute action, +1 = open, 0 = close
    absolute_gripper = rel2abs_gripper_actions(raw_action["gripper_closedness_action"][:, 0])

    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        absolute_gripper[:, None],
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)

    # decode compressed state
    observation = trajectory["observation"]
    pose_bytes = tf.io.decode_compressed(
        observation["clip_function_input/base_pose_tool_reached"],
        compression_type="ZLIB",
    )
    pose_values = tf.io.decode_raw(pose_bytes, tf.float32)
    observation["clip_function_input/base_pose_tool_reached"] = tf.reshape(pose_values, (-1, 7))

    gripper_bytes = tf.io.decode_compressed(observation["gripper_closed"], compression_type="ZLIB")
    gripper_values = tf.io.decode_raw(gripper_bytes, tf.float32)
    observation["gripper_closed"] = tf.reshape(gripper_values, (-1, 1))

    trajectory["language_instruction"] = observation["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Derive EEF / gripper state from `robot_obs` and clip the gripper action.

    The action is taken from `rel_actions_world`: its first six columns are
    kept and its last column is clipped to [0, 1].
    """
    observation = trajectory["observation"]
    robot_obs = observation["robot_obs"]
    observation["state_eef"] = robot_obs[:, :6]
    observation["state_gripper"] = robot_obs[:, 7:8]

    world_action = trajectory["action"]["rel_actions_world"]
    trajectory["action"] = world_action

    # clip gripper action, +1 = open, 0 = close
    # NOTE(review): the original comment also mentioned inverting the gripper
    # action, but only clipping is applied here - confirm which is intended.
    clipped_gripper = tf.clip_by_value(world_action[:, -1:], 0, 1)
    trajectory["action"] = tf.concat([world_action[:, :6], clipped_gripper], axis=-1)

    trajectory["language_instruction"] = observation["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `end_effector_cartesian_pos` into EEF / gripper state and flatten the action.

    The action becomes `world_vector`, a zero block of the same shape, and
    the absolute gripper action.
    """
    observation = trajectory["observation"]
    cartesian_pos = observation["end_effector_cartesian_pos"]
    observation["state_eef"] = cartesian_pos[:, :6]
    observation["state_gripper"] = cartesian_pos[:, -1:]

    raw_action = trajectory["action"]

    # make gripper action absolute action, +1 = open, 0 = close
    absolute_gripper = rel2abs_gripper_actions(raw_action["gripper_closedness_action"][:, 0])

    action_parts = [
        raw_action["world_vector"],
        tf.zeros_like(raw_action["world_vector"]),
        absolute_gripper[:, None],
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = observation["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict, appending a single all-zero column as the last channel."""
    raw_action = trajectory["action"]
    world_vector = raw_action["world_vector"]
    action_parts = [
        world_vector,
        raw_action["rotation_delta"],
        tf.zeros_like(raw_action["world_vector"][:, :1]),
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict and copy the language instruction and embedding to the top level."""
    raw_action = trajectory["action"]

    # invert absolute gripper action, +1 = open, 0 = close
    clipped_closedness = tf.clip_by_value(raw_action["gripper_closedness_action"], 0, 1)
    gripper_action = invert_gripper_actions(clipped_closedness)

    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        gripper_action,
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)

    observation = trajectory["observation"]
    trajectory["language_instruction"] = observation["natural_language_instruction"]
    trajectory["language_embedding"] = observation["natural_language_embedding"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict into one tensor and expose the language instruction."""
    raw_action = trajectory["action"]

    # make gripper action absolute action, +1 = open, 0 = close
    absolute_gripper = rel2abs_gripper_actions(raw_action["gripper_closedness_action"][:, 0])

    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        absolute_gripper[:, None],
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict, with the gripper channel clipped to [0, 1] and inverted."""
    raw_action = trajectory["action"]

    # make gripper action, +1 = open, 0 = close
    closedness = raw_action["gripper_closedness_action"][:, None]
    gripper_action = invert_gripper_actions(tf.clip_by_value(closedness, 0, 1))

    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        gripper_action,
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Take columns 6..13 of `robot_state` as `state` and flatten the action dict."""
    observation = trajectory["observation"]
    observation["state"] = observation["robot_state"][:, 6:14]

    raw_action = trajectory["action"]

    # make gripper action absolute action, +1 = open, 0 = close
    absolute_gripper = rel2abs_gripper_actions(raw_action["gripper_closedness_action"])

    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        absolute_gripper[:, None],
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = observation["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict, casting `open_gripper` to float32 as the last channel."""
    raw_action = trajectory["action"]
    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        tf.cast(raw_action["open_gripper"][:, None], tf.float32),
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Pad the action with zeros plus a ones column, and decode the instruction to a string.

    The instruction is UTF-8 encoded from `observation["instruction"]` and
    truncated at the first "\\x00".
    """
    action = trajectory["action"]
    zero_block = tf.zeros_like(action)

    # default to "open" gripper
    trajectory["action"] = tf.concat(
        [action, zero_block, zero_block, tf.ones_like(action[:, :1])],
        axis=-1,
    )

    # decode language instruction
    code_points = trajectory["observation"]["instruction"]
    encoded = tf.strings.unicode_encode(code_points, output_encoding="UTF-8")
    # Remove trailing padding --> convert RaggedTensor to regular Tensor.
    first_segment = tf.strings.split(encoded, "\x00")[:, :1]
    trajectory["language_instruction"] = first_segment.to_tensor()[:, 0]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the action dict, using `gripper_closedness_action` unchanged as the last channel."""
    raw_action = trajectory["action"]
    action_parts = [
        raw_action["world_vector"],
        raw_action["rotation_delta"],
        raw_action["gripper_closedness_action"][:, None],
    ]
    trajectory["action"] = tf.concat(action_parts, axis=-1)
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Drop the trailing axis of `depth_image` and zero-pad the action.

    The action becomes its first three columns, three zero columns, and its
    last column.
    """
    observation = trajectory["observation"]
    observation["depth_image"] = observation["depth_image"][..., 0]

    action = trajectory["action"]
    leading = action[:, :3]
    trajectory["action"] = tf.concat(
        [leading, tf.zeros_like(action[:, :3]), action[:, -1:]],
        axis=-1,
    )
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into EEF / gripper parts and keep the first seven action entries."""
    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = state[..., :6]
    observation["gripper_state"] = state[..., -1:]

    action = trajectory["action"]
    trajectory["action"] = action[..., :7]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Invert the gripper action and assemble EEF / gripper state from `state` columns.

    `eef_state` joins state columns 0..2 and 7..9; `gripper_state` is the
    third-from-last state column.
    """
    action = trajectory["action"]

    # invert gripper action, +1 = open, 0 = close
    inverted_gripper = invert_gripper_actions(action[:, -1:])
    trajectory["action"] = tf.concat([action[:, :6], inverted_gripper], axis=-1)

    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = tf.concat([state[:, :3], state[:, 7:10]], axis=-1)
    observation["gripper_state"] = state[:, -3:-2]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Clip and invert the gripper action, and truncate `state` to its first eight columns."""
    action = trajectory["action"]

    # invert gripper action + clip, +1 = open, 0 = close
    gripper_action = invert_gripper_actions(tf.clip_by_value(action[:, -1:], 0, 1))
    trajectory["action"] = tf.concat([action[:, :6], gripper_action], axis=-1)

    observation = trajectory["observation"]
    observation["state"] = observation["state"][:, :8]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Cast the depth views to float32, take `eef_state` from `state`, and rebuild the action.

    The action keeps columns -8..-3 and appends column -2 clipped to [0, 1].
    """
    observation = trajectory["observation"]
    for depth_key in ("depth", "depth_additional_view"):
        # keep index 0 of the trailing axis, as float32
        observation[depth_key] = tf.cast(observation[depth_key][..., 0], tf.float32)
    observation["eef_state"] = observation["state"][:, -6:]

    action = trajectory["action"]

    # clip gripper action, +1 = open, 0 = close
    clipped_gripper = tf.clip_by_value(action[:, -2:-1], 0, 1)
    trajectory["action"] = tf.concat([action[:, -8:-2], clipped_gripper], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Expose entry 7 of the last `state` axis as `gripper_state` (trailing axis kept)."""
    observation = trajectory["observation"]
    state = observation["state"]
    observation["gripper_state"] = state[..., 7:8]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Trim `state` and convert the action's quaternion columns to Euler angles.

    `state` keeps its first seven columns plus its last column. The action
    becomes columns 0..2, `tft.euler.from_quaternion` of columns 3..6, and
    the clipped, inverted last column.
    """
    import tensorflow_graphics.geometry.transformation as tft

    observation = trajectory["observation"]
    state = observation["state"]
    observation["state"] = tf.concat([state[:, :7], state[:, -1:]], axis=-1)

    action = trajectory["action"]
    position = action[:, :3]
    euler_rotation = tft.euler.from_quaternion(action[:, 3:7])

    # invert gripper action + clip, +1 = open, 0 = close
    gripper_action = invert_gripper_actions(tf.clip_by_value(action[:, -1:], 0, 1))
    trajectory["action"] = tf.concat([position, euler_rotation, gripper_action], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Drop the last entry along the final axis of the action."""
    action = trajectory["action"]
    trajectory["action"] = action[..., :-1]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Expose the first seven state columns as `joint_state` and drop the last action entry."""
    observation = trajectory["observation"]
    observation["joint_state"] = observation["state"][:, :7]

    action = trajectory["action"]
    trajectory["action"] = action[..., :-1]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into EEF / gripper parts and zero-pad the action.

    The action becomes its first three columns, three zero columns, and its
    last column.
    """
    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = state[:, :6]
    observation["gripper_state"] = state[:, -1:]

    action = trajectory["action"]
    leading = action[:, :3]
    trajectory["action"] = tf.concat(
        [leading, tf.zeros_like(action[:, :3]), action[:, -1:]],
        axis=-1,
    )
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Keep the first six action columns and append the clipped, inverted last column."""
    action = trajectory["action"]

    # invert gripper action + clip, +1 = open, 0 = close
    gripper_action = invert_gripper_actions(tf.clip_by_value(action[:, -1:], 0, 1))
    trajectory["action"] = tf.concat([action[:, :6], gripper_action], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Keep the first six action columns and append the clipped, inverted last column."""
    action = trajectory["action"]

    # invert gripper action + clip, +1 = open, 0 = close
    gripper_action = invert_gripper_actions(tf.clip_by_value(action[:, -1:], 0, 1))
    trajectory["action"] = tf.concat([action[:, :6], gripper_action], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Build the action from the `future/*` fields and expose the language instruction.

    The gripper channel is the first column of `future/target_close`, cast
    to float32 and passed through `invert_gripper_actions`.
    """
    raw_action = trajectory["action"]
    xyz_residual = raw_action["future/xyz_residual"][:, :3]
    axis_angle_residual = raw_action["future/axis_angle_residual"][:, :3]
    target_close = tf.cast(raw_action["future/target_close"][:, :1], tf.float32)

    trajectory["action"] = tf.concat(
        [xyz_residual, axis_angle_residual, invert_gripper_actions(target_close)],
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into EEF / gripper parts and drop the last action entry."""
    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = state[:, :6]
    observation["gripper_state"] = state[:, -1:]

    action = trajectory["action"]
    trajectory["action"] = action[..., :-1]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into EEF / gripper parts and drop the last action entry."""
    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = state[:, :6]
    observation["gripper_state"] = state[:, -1:]

    action = trajectory["action"]
    trajectory["action"] = action[..., :-1]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the last seven entries along the final axis of the action."""
    action = trajectory["action"]
    trajectory["action"] = action[..., -7:]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Zero-pad the state-derived EEF state and the action.

    `eef_state` is the first four state columns plus two zero columns;
    `gripper_state` is the last state column. The action becomes its first
    four columns, two zero columns, and its last column.
    """
    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = tf.concat(
        [state[:, :4], tf.zeros_like(state[:, :2])],
        axis=-1,
    )
    observation["gripper_state"] = observation["state"][:, -1:]

    action = trajectory["action"]
    trajectory["action"] = tf.concat(
        [action[:, :4], tf.zeros_like(action[:, :2]), action[:, -1:]],
        axis=-1,
    )
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Cast `observation["gripper"]` to float32 and add a trailing axis.

    The original body carried a disabled alternative (kept in a string) that
    instead built `observation["state"]` by concatenating the float32
    gripper, `pose` and `joint_pos` along the last axis; it is not executed.
    """
    observation = trajectory["observation"]
    gripper_column = observation["gripper"][:, None]
    observation["gripper"] = tf.cast(gripper_column, tf.float32)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Cast `observation["gripper"]` to float32 and add a trailing axis."""
    observation = trajectory["observation"]
    gripper_column = observation["gripper"][:, None]
    observation["gripper"] = tf.cast(gripper_column, tf.float32)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Keep the last seven state columns and append one zero column to the first six action columns."""
    observation = trajectory["observation"]
    observation["state"] = observation["state"][:, -7:]

    action = trajectory["action"]
    zero_column = tf.zeros_like(action[:, :1])
    trajectory["action"] = tf.concat([action[:, :6], zero_column], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Zero-pad the EEF state derived from `end_effector_pose`, and the action.

    `eef_state` is the first four pose columns plus two zero columns;
    `gripper_state` is the last pose column. The action becomes its first
    four columns, two zero columns, and its last column.
    """
    observation = trajectory["observation"]
    pose = observation["end_effector_pose"]
    observation["eef_state"] = tf.concat(
        [pose[:, :4], tf.zeros_like(pose[:, :2])],
        axis=-1,
    )
    observation["gripper_state"] = pose[:, -1:]

    action = trajectory["action"]
    trajectory["action"] = tf.concat(
        [action[:, :4], tf.zeros_like(action[:, :2]), action[:, -1:]],
        axis=-1,
    )
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into `eef_state` (first six columns) and `gripper_state` (last column)."""
    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = state[:, :6]
    observation["gripper_state"] = state[:, -1:]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Truncate `observation["state"]` to its first six columns."""
    observation = trajectory["observation"]
    full_state = observation["state"]
    observation["state"] = full_state[:, :6]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Keep the first six action columns and append the inverted last column."""
    action = trajectory["action"]

    # invert gripper action, +1 = open, 0 = close
    inverted_gripper = invert_gripper_actions(action[:, -1:])
    trajectory["action"] = tf.concat([action[:, :6], inverted_gripper], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Take `eef_state` from `ground_truth_states["EE"]` and `gripper_state` from the last state column."""
    observation = trajectory["observation"]
    ground_truth = trajectory["ground_truth_states"]
    observation["eef_state"] = ground_truth["EE"]
    observation["gripper_state"] = observation["state"][:, -1:]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into `eef_state` (first six columns) and `gripper_state` (last column)."""
    observation = trajectory["observation"]
    state = observation["state"]
    observation["eef_state"] = state[:, :6]
    observation["gripper_state"] = state[:, -1:]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Drop the last entry along the final axis of the action."""
    action = trajectory["action"]
    trajectory["action"] = action[..., :-1]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into joint / gripper parts and convert the action quaternion to Euler angles.

    The action becomes columns 0..2, `tft.euler.from_quaternion` of columns
    3..6, and column 7.
    """
    import tensorflow_graphics.geometry.transformation as tft

    observation = trajectory["observation"]
    state = observation["state"]
    observation["joint_state"] = state[:, :7]
    observation["gripper_state"] = state[:, 7:8]

    action = trajectory["action"]
    position = action[:, :3]
    euler_rotation = tft.euler.from_quaternion(action[:, 3:7])
    trajectory["action"] = tf.concat([position, euler_rotation, action[:, 7:8]], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Pad the action with a same-shaped zero block and one extra zero column."""
    action = trajectory["action"]
    zero_block = tf.zeros_like(action)
    zero_column = tf.zeros_like(action[:, :1])
    trajectory["action"] = tf.concat([action, zero_block, zero_column], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Truncate `state` to eight columns and clip + invert the gripper action."""
    observation = trajectory["observation"]
    observation["state"] = observation["state"][:, :8]

    action = trajectory["action"]

    # invert gripper action + clip, +1 = open, 0 = close
    gripper_action = invert_gripper_actions(tf.clip_by_value(action[:, -1:], 0, 1))
    trajectory["action"] = tf.concat([action[:, :6], gripper_action], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Split `state` into joint / gripper parts and append an inverted gripper state to the action."""
    observation = trajectory["observation"]
    state = observation["state"]
    observation["joint_state"] = state[:, :6]
    observation["gripper_state"] = state[:, 6:7]

    # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
    gripper_action = invert_gripper_actions(observation["gripper_state"])
    trajectory["action"] = tf.concat([trajectory["action"], gripper_action], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Convert the action's quaternion columns (3..6) to Euler angles.

    The action becomes columns 0..2, the Euler angles, and the last column.
    """
    import tensorflow_graphics.geometry.transformation as tft

    action = trajectory["action"]
    position = action[:, :3]
    euler_rotation = tft.euler.from_quaternion(action[:, 3:7])
    trajectory["action"] = tf.concat([position, euler_rotation, action[:, -1:]], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild the action from its first three and last four columns."""
    action = trajectory["action"]
    head_columns = action[:, :3]
    tail_columns = action[:, -4:]
    trajectory["action"] = tf.concat([head_columns, tail_columns], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Build a zero-padded EEF state, expose the gripper state, and drop the last action entry.

    `eef_state` is the first three state columns followed by three zero
    columns; `gripper_state` is the last state column.
    """
    observation = trajectory["observation"]
    state = observation["state"]
    xyz = state[:, :3]
    observation["eef_state"] = tf.concat([xyz, tf.zeros_like(state[:, :3])], axis=-1)
    observation["gripper_state"] = observation["state"][:, -1:]

    action = trajectory["action"]
    trajectory["action"] = action[..., :-1]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild `state` from position and yaw, and zero-pad the action.

    `state` becomes `position`, three zero columns (shaped like the first
    three columns of the incoming state), and `yaw`. The action is followed
    by two same-shaped zero blocks and one extra zero column.
    """
    observation = trajectory["observation"]
    state_parts = [
        observation["position"],
        tf.zeros_like(observation["state"][:, :3]),
        observation["yaw"],
    ]
    observation["state"] = tf.concat(state_parts, axis=-1)

    action = trajectory["action"]
    zero_block = tf.zeros_like(action)
    zero_column = tf.zeros_like(action[:, :1])
    trajectory["action"] = tf.concat([action, zero_block, zero_block, zero_column], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Concatenate `eef_pose` and `state_gripper_pose` into `observation["proprio"]`.

    `state_gripper_pose` gets a new trailing axis before concatenation.
    """
    # every input feature is batched, ie has leading batch dimension
    observation = trajectory["observation"]
    eef_pose = observation["eef_pose"]
    gripper_pose = observation["state_gripper_pose"][..., None]
    observation["proprio"] = tf.concat([eef_pose, gripper_pose], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Alias `observation["state"]` as `observation["proprio"]`."""
    # every input feature is batched, ie has leading batch dimension
    observation = trajectory["observation"]
    observation["proprio"] = observation["state"]
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Keep the first seven action columns and append the clipped, inverted last column."""
    action = trajectory["action"]

    # gripper action is in -1...1 --> clip to 0...1, flip
    clipped_gripper = tf.clip_by_value(action[:, -1:], 0, 1)
    gripper_action = invert_gripper_actions(clipped_gripper)

    trajectory["action"] = tf.concat([trajectory["action"][:, :7], gripper_action], axis=-1)
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """Return the trajectory unchanged (the same object, not a copy)."""
    return trajectory
|
|
||||||
|
|
||||||
|
|
||||||
# === Registry ===
|
|
||||||
# Maps a string key to the transform callable for that dataset.
# NOTE(review): keys look like Open X-Embodiment dataset names and are
# presumably looked up by dataset name - verify against the caller.
# The two "droid" entries call `droid_baseact_transform_fn()` while this dict
# is being built, so its import of `droid_utils` runs at module import time.
OPENX_STANDARDIZATION_TRANSFORMS = {
    "bridge_openx": bridge_openx_dataset_transform,
    "bridge_orig": bridge_orig_dataset_transform,
    "bridge_dataset": bridge_orig_dataset_transform,
    "ppgm": ppgm_dataset_transform,
    "ppgm_static": ppgm_dataset_transform,
    "ppgm_wrist": ppgm_dataset_transform,
    "fractal20220817_data": rt1_dataset_transform,
    "kuka": kuka_dataset_transform,
    "taco_play": taco_play_dataset_transform,
    "jaco_play": jaco_play_dataset_transform,
    "berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
    "roboturk": roboturk_dataset_transform,
    "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
    "viola": viola_dataset_transform,
    "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
    "toto": toto_dataset_transform,
    "language_table": language_table_dataset_transform,
    "columbia_cairlab_pusht_real": pusht_dataset_transform,
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
    "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
    "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
    "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
    "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
    "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
    "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
    "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
    "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
    "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
    "bc_z": bc_z_dataset_transform,
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform,
    "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
    "robo_net": robo_net_dataset_transform,
    "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
    "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
    "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
    "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
    "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
    "dlr_sara_pour_converted_externally_to_rlds": identity_transform,
    "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
    "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
    "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
    "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
    "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
    "uiuc_d3field": uiuc_d3field_dataset_transform,
    "utaustin_mutex": utaustin_mutex_dataset_transform,
    "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
    "cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
    "cmu_play_fusion": playfusion_dataset_transform,
    "cmu_stretch": cmu_stretch_dataset_transform,
    "berkeley_gnm_recon": gnm_dataset_transform,
    "berkeley_gnm_cory_hall": gnm_dataset_transform,
    "berkeley_gnm_sac_son": gnm_dataset_transform,
    "droid": droid_baseact_transform_fn(),
    "droid_100": droid_baseact_transform_fn(),  # first 100 episodes of droid
    "fmb": fmb_transform,
    "dobbe": dobbe_dataset_transform,
    "robo_set": robo_set_dataset_transform,
    # datasets below are passed through unchanged
    "usc_cloth_sim_converted_externally_to_rlds": identity_transform,
    "plex_robosuite": identity_transform,
    "conq_hose_manipulation": identity_transform,
    "io_ai_tech": identity_transform,
    "spoc": identity_transform,
}
|
|
||||||
@@ -14,13 +14,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
For all datasets in the RLDS format.
|
||||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||||
|
|
||||||
|
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
python lerobot/scripts/push_dataset_to_hub.py \
|
python lerobot/scripts/push_dataset_to_hub.py \
|
||||||
--raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
|
--raw-dir /path/to/data/bridge_dataset/1.0.0/ \
|
||||||
--repo-id youliangtan/sampled_bridge_data_v2 \
|
--repo-id your_hub/sampled_bridge_data_v2 \
|
||||||
--raw-format openx_rlds.bridge_orig \
|
--raw-format rlds \
|
||||||
--episodes 3 4 5 8 9
|
--episodes 3 4 5 8 9
|
||||||
|
|
||||||
Exact dataset fps defined in openx/config.py, obtained from:
|
Exact dataset fps defined in openx/config.py, obtained from:
|
||||||
@@ -35,12 +38,10 @@ import tensorflow as tf
|
|||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
import yaml
|
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||||
calculate_episode_data_index,
|
calculate_episode_data_index,
|
||||||
concatenate_episodes,
|
concatenate_episodes,
|
||||||
@@ -52,11 +53,6 @@ from lerobot.common.datasets.utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||||
|
|
||||||
with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
|
|
||||||
_openx_list = yaml.safe_load(f)
|
|
||||||
|
|
||||||
OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
|
|
||||||
|
|
||||||
np.set_printoptions(precision=2)
|
np.set_printoptions(precision=2)
|
||||||
|
|
||||||
|
|
||||||
@@ -108,7 +104,6 @@ def load_from_raw(
|
|||||||
video: bool,
|
video: bool,
|
||||||
episodes: list[int] | None = None,
|
episodes: list[int] | None = None,
|
||||||
encoding: dict | None = None,
|
encoding: dict | None = None,
|
||||||
openx_dataset_name: str | None = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -136,16 +131,17 @@ def load_from_raw(
|
|||||||
# we will apply the standardization transform if the dataset_name is provided
|
# we will apply the standardization transform if the dataset_name is provided
|
||||||
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
|
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
|
||||||
# search for 'image' keys in the observations
|
# search for 'image' keys in the observations
|
||||||
if openx_dataset_name is not None:
|
image_keys = []
|
||||||
print(" - applying standardization transform for dataset: ", openx_dataset_name)
|
state_keys = []
|
||||||
assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
|
observation_info = dataset_info.features["steps"]["observation"]
|
||||||
transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
|
for key in observation_info:
|
||||||
dataset = dataset.map(transform_fn)
|
# check whether the key is for an image or a vector observation
|
||||||
|
if len(observation_info[key].shape) == 3:
|
||||||
image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
|
# only adding uint8 images discards depth images
|
||||||
else:
|
if observation_info[key].dtype == tf.uint8:
|
||||||
obs_keys = dataset_info.features["steps"]["observation"].keys()
|
image_keys.append(key)
|
||||||
image_keys = [key for key in obs_keys if "image" in key]
|
else:
|
||||||
|
state_keys.append(key)
|
||||||
|
|
||||||
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
|
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
|
||||||
|
|
||||||
@@ -193,50 +189,31 @@ def load_from_raw(
|
|||||||
|
|
||||||
num_frames = episode["action"].shape[0]
|
num_frames = episode["action"].shape[0]
|
||||||
|
|
||||||
###########################################################
|
|
||||||
# Handle the episodic data
|
|
||||||
|
|
||||||
# last step of demonstration is considered done
|
|
||||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
done[-1] = True
|
|
||||||
ep_dict = {}
|
ep_dict = {}
|
||||||
langs = [] # TODO: might be located in "observation"
|
for key in state_keys:
|
||||||
|
ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
|
||||||
|
|
||||||
image_array_dict = {key: [] for key in image_keys}
|
ep_dict["action"] = tf_to_torch(episode["action"])
|
||||||
|
ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
|
||||||
# We will create the state observation tensor by stacking the state
|
ep_dict["next.done"] = tf_to_torch(episode["is_last"])
|
||||||
# obs keys defined in the openx/configs.py
|
ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
|
||||||
if openx_dataset_name is not None:
|
ep_dict["is_first"] = tf_to_torch(episode["is_first"])
|
||||||
state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
|
ep_dict["discount"] = tf_to_torch(episode["discount"])
|
||||||
# stack the state observations, if is None, pad with zeros
|
|
||||||
states = []
|
|
||||||
for key in state_obs_keys:
|
|
||||||
if key in episode["observation"]:
|
|
||||||
states.append(tf_to_torch(episode["observation"][key]))
|
|
||||||
else:
|
|
||||||
states.append(torch.zeros(num_frames, 1)) # pad with zeros
|
|
||||||
states = torch.cat(states, dim=1)
|
|
||||||
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
|
|
||||||
else:
|
|
||||||
states = tf_to_torch(episode["observation"]["state"])
|
|
||||||
|
|
||||||
actions = tf_to_torch(episode["action"])
|
|
||||||
rewards = tf_to_torch(episode["reward"]).float()
|
|
||||||
|
|
||||||
# If lang_key is present, convert the entire tensor at once
|
# If lang_key is present, convert the entire tensor at once
|
||||||
if lang_key is not None:
|
if lang_key is not None:
|
||||||
langs = [str(x) for x in episode[lang_key]]
|
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
|
||||||
|
|
||||||
|
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||||
|
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||||
|
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||||
|
|
||||||
|
image_array_dict = {key: [] for key in image_keys}
|
||||||
|
|
||||||
for im_key in image_keys:
|
for im_key in image_keys:
|
||||||
imgs = episode["observation"][im_key]
|
imgs = episode["observation"][im_key]
|
||||||
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
|
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
|
||||||
|
|
||||||
# simple assertions
|
|
||||||
for item in [states, actions, rewards, done]:
|
|
||||||
assert len(item) == num_frames
|
|
||||||
|
|
||||||
###########################################################
|
|
||||||
|
|
||||||
# loop through all cameras
|
# loop through all cameras
|
||||||
for im_key in image_keys:
|
for im_key in image_keys:
|
||||||
img_key = f"observation.images.{im_key}"
|
img_key = f"observation.images.{im_key}"
|
||||||
@@ -262,17 +239,6 @@ def load_from_raw(
|
|||||||
else:
|
else:
|
||||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||||
|
|
||||||
if lang_key is not None:
|
|
||||||
ep_dict["language_instruction"] = langs
|
|
||||||
|
|
||||||
ep_dict["observation.state"] = states
|
|
||||||
ep_dict["action"] = actions
|
|
||||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
|
||||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
|
||||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
|
||||||
ep_dict["next.reward"] = rewards
|
|
||||||
ep_dict["next.done"] = done
|
|
||||||
|
|
||||||
path_ep_dict = tmp_ep_dicts_dir.joinpath(
|
path_ep_dict = tmp_ep_dicts_dir.joinpath(
|
||||||
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
|
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
|
||||||
)
|
)
|
||||||
@@ -290,30 +256,28 @@ def load_from_raw(
|
|||||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||||
features = {}
|
features = {}
|
||||||
|
|
||||||
keys = [key for key in data_dict if "observation.images." in key]
|
for key in data_dict:
|
||||||
for key in keys:
|
# check if vector state obs
|
||||||
if video:
|
if key.startswith("observation.") and "observation.images." not in key:
|
||||||
features[key] = VideoFrame()
|
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||||
else:
|
# check if image obs
|
||||||
features[key] = Image()
|
elif "observation.images." in key:
|
||||||
|
if video:
|
||||||
|
features[key] = VideoFrame()
|
||||||
|
else:
|
||||||
|
features[key] = Image()
|
||||||
|
|
||||||
features["observation.state"] = Sequence(
|
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
)
|
|
||||||
if "observation.velocity" in data_dict:
|
|
||||||
features["observation.velocity"] = Sequence(
|
|
||||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
)
|
|
||||||
if "observation.effort" in data_dict:
|
|
||||||
features["observation.effort"] = Sequence(
|
|
||||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
)
|
|
||||||
if "language_instruction" in data_dict:
|
if "language_instruction" in data_dict:
|
||||||
features["language_instruction"] = Value(dtype="string", id=None)
|
features["language_instruction"] = Value(dtype="string", id=None)
|
||||||
|
|
||||||
features["action"] = Sequence(
|
features["action"] = Sequence(
|
||||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
features["is_terminal"] = Value(dtype="bool", id=None)
|
||||||
|
features["is_first"] = Value(dtype="bool", id=None)
|
||||||
|
features["discount"] = Value(dtype="float32", id=None)
|
||||||
|
|
||||||
features["episode_index"] = Value(dtype="int64", id=None)
|
features["episode_index"] = Value(dtype="int64", id=None)
|
||||||
features["frame_index"] = Value(dtype="int64", id=None)
|
features["frame_index"] = Value(dtype="int64", id=None)
|
||||||
features["timestamp"] = Value(dtype="float32", id=None)
|
features["timestamp"] = Value(dtype="float32", id=None)
|
||||||
@@ -333,19 +297,8 @@ def from_raw_to_lerobot_format(
|
|||||||
video: bool = True,
|
video: bool = True,
|
||||||
episodes: list[int] | None = None,
|
episodes: list[int] | None = None,
|
||||||
encoding: dict | None = None,
|
encoding: dict | None = None,
|
||||||
openx_dataset_name: str | None = None,
|
|
||||||
):
|
):
|
||||||
"""This is a test impl for rlds conversion"""
|
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||||
if openx_dataset_name is None:
|
|
||||||
# set a default rlds frame rate if the dataset is not from openx
|
|
||||||
fps = 30
|
|
||||||
elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
|
|
||||||
raise ValueError(
|
|
||||||
"fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
|
|
||||||
)
|
|
||||||
fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
|
|
||||||
|
|
||||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
|
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
info = {
|
info = {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import importlib.resources
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
@@ -476,6 +477,8 @@ def create_lerobot_dataset_card(
|
|||||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||||
"""
|
"""
|
||||||
card_tags = ["LeRobot"]
|
card_tags = ["LeRobot"]
|
||||||
|
card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md")
|
||||||
|
|
||||||
if tags:
|
if tags:
|
||||||
card_tags += tags
|
card_tags += tags
|
||||||
if dataset_info:
|
if dataset_info:
|
||||||
@@ -493,8 +496,9 @@ def create_lerobot_dataset_card(
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return DatasetCard.from_template(
|
return DatasetCard.from_template(
|
||||||
card_data=card_data,
|
card_data=card_data,
|
||||||
template_path="./lerobot/common/datasets/card_template.md",
|
template_path=str(card_template_path),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -287,8 +287,11 @@ def control_loop(
|
|||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
frame = {**observation, **action}
|
frame = {**observation, **action}
|
||||||
|
<<<<<<< HEAD
|
||||||
if "next.reward" in events:
|
if "next.reward" in events:
|
||||||
frame["next.reward"] = events["next.reward"]
|
frame["next.reward"] = events["next.reward"]
|
||||||
|
=======
|
||||||
|
>>>>>>> main
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_cameras and not is_headless():
|
if display_cameras and not is_headless():
|
||||||
@@ -372,7 +375,11 @@ def sanity_check_dataset_robot_compatibility(
|
|||||||
|
|
||||||
mismatches = []
|
mismatches = []
|
||||||
for field, dataset_value, present_value in fields:
|
for field, dataset_value, present_value in fields:
|
||||||
|
<<<<<<< HEAD
|
||||||
diff = DeepDiff(dataset_value, present_value)
|
diff = DeepDiff(dataset_value, present_value)
|
||||||
|
=======
|
||||||
|
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
|
||||||
|
>>>>>>> main
|
||||||
if diff:
|
if diff:
|
||||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ python lerobot/scripts/control_robot.py teleoperate \
|
|||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root tmp/data \
|
|
||||||
--repo-id $USER/koch_test \
|
--repo-id $USER/koch_test \
|
||||||
--num-episodes 1 \
|
--num-episodes 1 \
|
||||||
--run-compute-stats 0
|
--run-compute-stats 0
|
||||||
@@ -38,7 +37,6 @@ python lerobot/scripts/control_robot.py record \
|
|||||||
- Visualize dataset:
|
- Visualize dataset:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/visualize_dataset.py \
|
python lerobot/scripts/visualize_dataset.py \
|
||||||
--root tmp/data \
|
|
||||||
--repo-id $USER/koch_test \
|
--repo-id $USER/koch_test \
|
||||||
--episode-index 0
|
--episode-index 0
|
||||||
```
|
```
|
||||||
@@ -47,7 +45,6 @@ python lerobot/scripts/visualize_dataset.py \
|
|||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py replay \
|
python lerobot/scripts/control_robot.py replay \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root tmp/data \
|
|
||||||
--repo-id $USER/koch_test \
|
--repo-id $USER/koch_test \
|
||||||
--episode 0
|
--episode 0
|
||||||
```
|
```
|
||||||
@@ -57,7 +54,6 @@ python lerobot/scripts/control_robot.py replay \
|
|||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id $USER/koch_pick_place_lego \
|
--repo-id $USER/koch_pick_place_lego \
|
||||||
--num-episodes 50 \
|
--num-episodes 50 \
|
||||||
--warmup-time-s 2 \
|
--warmup-time-s 2 \
|
||||||
@@ -72,12 +68,12 @@ python lerobot/scripts/control_robot.py record \
|
|||||||
- Tap escape key 'esc' to stop the data recording.
|
- Tap escape key 'esc' to stop the data recording.
|
||||||
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||||
|
|
||||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--resume 1`.
|
||||||
To avoid resuming by deleting the dataset, use `--force-override 1`.
|
If the dataset you want to extend is not on the hub, you also need to add `--local-files-only 1`.
|
||||||
|
|
||||||
- Train on this dataset with the ACT policy:
|
- Train on this dataset with the ACT policy:
|
||||||
```bash
|
```bash
|
||||||
DATA_DIR=data python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=act_koch_real \
|
policy=act_koch_real \
|
||||||
env=koch_real \
|
env=koch_real \
|
||||||
dataset_repo_id=$USER/koch_pick_place_lego \
|
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||||
@@ -88,7 +84,6 @@ DATA_DIR=data python lerobot/scripts/train.py \
|
|||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/control_robot.py record \
|
python lerobot/scripts/control_robot.py record \
|
||||||
--fps 30 \
|
--fps 30 \
|
||||||
--root data \
|
|
||||||
--repo-id $USER/eval_act_koch_real \
|
--repo-id $USER/eval_act_koch_real \
|
||||||
--num-episodes 10 \
|
--num-episodes 10 \
|
||||||
--warmup-time-s 2 \
|
--warmup-time-s 2 \
|
||||||
@@ -191,7 +186,7 @@ def teleoperate(
|
|||||||
@safe_disconnect
|
@safe_disconnect
|
||||||
def record(
|
def record(
|
||||||
robot: Robot,
|
robot: Robot,
|
||||||
root: str,
|
root: Path,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
single_task: str,
|
single_task: str,
|
||||||
pretrained_policy_name_or_path: str | None = None,
|
pretrained_policy_name_or_path: str | None = None,
|
||||||
@@ -205,6 +200,10 @@ def record(
|
|||||||
video: bool = True,
|
video: bool = True,
|
||||||
run_compute_stats: bool = True,
|
run_compute_stats: bool = True,
|
||||||
push_to_hub: bool = True,
|
push_to_hub: bool = True,
|
||||||
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
>>>>>>> main
|
||||||
num_image_writer_processes: int = 0,
|
num_image_writer_processes: int = 0,
|
||||||
num_image_writer_threads_per_camera: int = 4,
|
num_image_writer_threads_per_camera: int = 4,
|
||||||
display_cameras: bool = True,
|
display_cameras: bool = True,
|
||||||
@@ -223,6 +222,11 @@ def record(
|
|||||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if single_task:
|
||||||
|
task = single_task
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only single-task recording is supported for now")
|
||||||
|
|
||||||
if single_task:
|
if single_task:
|
||||||
task = single_task
|
task = single_task
|
||||||
else:
|
else:
|
||||||
@@ -262,7 +266,10 @@ def record(
|
|||||||
use_videos=video,
|
use_videos=video,
|
||||||
image_writer_processes=num_image_writer_processes,
|
image_writer_processes=num_image_writer_processes,
|
||||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||||
|
<<<<<<< HEAD
|
||||||
features=extra_features,
|
features=extra_features,
|
||||||
|
=======
|
||||||
|
>>>>>>> main
|
||||||
)
|
)
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
@@ -335,7 +342,11 @@ def record(
|
|||||||
dataset.consolidate(run_compute_stats)
|
dataset.consolidate(run_compute_stats)
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
|
<<<<<<< HEAD
|
||||||
dataset.push_to_hub()
|
dataset.push_to_hub()
|
||||||
|
=======
|
||||||
|
dataset.push_to_hub(tags=tags)
|
||||||
|
>>>>>>> main
|
||||||
|
|
||||||
log_say("Exiting", play_sounds)
|
log_say("Exiting", play_sounds)
|
||||||
return dataset
|
return dataset
|
||||||
@@ -349,7 +360,11 @@ def replay(
|
|||||||
episode: int,
|
episode: int,
|
||||||
fps: int | None = None,
|
fps: int | None = None,
|
||||||
play_sounds: bool = True,
|
play_sounds: bool = True,
|
||||||
|
<<<<<<< HEAD
|
||||||
local_files_only: bool = True,
|
local_files_only: bool = True,
|
||||||
|
=======
|
||||||
|
local_files_only: bool = False,
|
||||||
|
>>>>>>> main
|
||||||
):
|
):
|
||||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||||
# TODO(rcadene): Add option to record logs
|
# TODO(rcadene): Add option to record logs
|
||||||
@@ -431,8 +446,8 @@ if __name__ == "__main__":
|
|||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
type=Path,
|
type=Path,
|
||||||
default="data",
|
default=None,
|
||||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
|
||||||
)
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
@@ -440,6 +455,12 @@ if __name__ == "__main__":
|
|||||||
default="lerobot/test",
|
default="lerobot/test",
|
||||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||||
)
|
)
|
||||||
|
parser_record.add_argument(
|
||||||
|
"--local-files-only",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||||
|
)
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"--warmup-time-s",
|
"--warmup-time-s",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -498,12 +519,21 @@ if __name__ == "__main__":
|
|||||||
"Not enough threads might cause low camera fps."
|
"Not enough threads might cause low camera fps."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
<<<<<<< HEAD
|
||||||
# parser_record.add_argument(
|
# parser_record.add_argument(
|
||||||
# "--force-override",
|
# "--force-override",
|
||||||
# type=int,
|
# type=int,
|
||||||
# default=0,
|
# default=0,
|
||||||
# help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
|
# help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
|
||||||
# )
|
# )
|
||||||
|
=======
|
||||||
|
parser_record.add_argument(
|
||||||
|
"--resume",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Resume recording on an existing dataset.",
|
||||||
|
)
|
||||||
|
>>>>>>> main
|
||||||
parser_record.add_argument(
|
parser_record.add_argument(
|
||||||
"-p",
|
"-p",
|
||||||
"--pretrained-policy-name-or-path",
|
"--pretrained-policy-name-or-path",
|
||||||
@@ -533,8 +563,8 @@ if __name__ == "__main__":
|
|||||||
parser_replay.add_argument(
|
parser_replay.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
type=Path,
|
type=Path,
|
||||||
default="data",
|
default=None,
|
||||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
|
||||||
)
|
)
|
||||||
parser_replay.add_argument(
|
parser_replay.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
@@ -542,6 +572,12 @@ if __name__ == "__main__":
|
|||||||
default="lerobot/test",
|
default="lerobot/test",
|
||||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||||
)
|
)
|
||||||
|
parser_replay.add_argument(
|
||||||
|
"--local-files-only",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||||
|
)
|
||||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
546
lerobot/scripts/control_sim_robot.py
Normal file
546
lerobot/scripts/control_sim_robot.py
Normal file
@@ -0,0 +1,546 @@
|
|||||||
|
"""
|
||||||
|
Utilities to control a robot in simulation.
|
||||||
|
|
||||||
|
Useful to record a dataset, replay a recorded episode and record an evaluation dataset.
|
||||||
|
|
||||||
|
Examples of usage:
|
||||||
|
|
||||||
|
|
||||||
|
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency.
|
||||||
|
You can modify this value depending on how fast your simulation can run:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_sim_robot.py teleoperate \
|
||||||
|
--fps 30 \
|
||||||
|
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||||
|
--sim-config lerobot/configs/env/your_sim_config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
- Record one episode in order to test replay:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_sim_robot.py record \
|
||||||
|
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||||
|
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
||||||
|
--fps 30 \
|
||||||
|
--repo-id $USER/robot_sim_test \
|
||||||
|
--num-episodes 1 \
|
||||||
|
--run-compute-stats 0
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass `--push-to-hub 1` to push the recorded dataset to the Hugging Face hub.
|
||||||
|
|
||||||
|
- Visualize dataset:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/visualize_dataset.py \
|
||||||
|
--repo-id $USER/robot_sim_test \
|
||||||
|
--episode-index 0
|
||||||
|
```
|
||||||
|
|
||||||
|
- Replay a sequence of test episodes:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_sim_robot.py replay \
|
||||||
|
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||||
|
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
||||||
|
--fps 30 \
|
||||||
|
--repo-id $USER/robot_sim_test \
|
||||||
|
--episode 0
|
||||||
|
```
|
||||||
|
Note: The seed is saved, therefore, during replay we can load the same environment state as the one during collection.
|
||||||
|
|
||||||
|
- Record a full dataset in order to train a policy,
|
||||||
|
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_sim_robot.py record \
|
||||||
|
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||||
|
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
||||||
|
--fps 30 \
|
||||||
|
--repo-id $USER/robot_sim_test \
|
||||||
|
--num-episodes 50 \
|
||||||
|
--episode-time-s 30
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: You can use your keyboard to control data recording flow.
|
||||||
|
- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
|
||||||
|
- Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
|
||||||
|
- Tap left arrow key '<-' to early exit and re-record the current episode.
|
||||||
|
- Tap escape key 'esc' to stop the data recording.
|
||||||
|
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||||
|
|
||||||
|
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.robot_devices.control_utils import (
|
||||||
|
init_keyboard_listener,
|
||||||
|
init_policy,
|
||||||
|
is_headless,
|
||||||
|
log_control_info,
|
||||||
|
predict_action,
|
||||||
|
sanity_check_dataset_name,
|
||||||
|
sanity_check_dataset_robot_compatibility,
|
||||||
|
stop_recording,
|
||||||
|
)
|
||||||
|
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||||
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
|
||||||
|
|
||||||
|
DEFAULT_FEATURES = {
|
||||||
|
"next.reward": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"next.success": {
|
||||||
|
"dtype": "bool",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"seed": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"timestamp": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################
|
||||||
|
# Utilities
|
||||||
|
########################################################################################
|
||||||
|
def none_or_int(value):
|
||||||
|
if value == "None":
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
|
||||||
|
|
||||||
|
def init_sim_calibration(robot, cfg):
    """Gather the constants needed to map real leader-arm joint positions into the sim.

    These depend on the robot description used by the simulation.

    Args:
        robot: Robot whose main leader arm exposes a `calibration` mapping with a
            "start_pos" entry.
        cfg: Mapping-like calibration config. "axis_directions" defaults to [1] and
            "offsets" defaults to [0] when absent.

    Returns:
        A dict with numpy arrays under "start_pos", "axis_directions" and "offsets",
        ready to be splatted into `real_positions_to_sim`.
    """
    calibration = {}
    calibration["start_pos"] = np.array(robot.leader_arms.main.calibration["start_pos"])
    calibration["axis_directions"] = np.array(cfg.get("axis_directions", [1]))
    # Offsets are configured as multiples of pi and stored already scaled.
    calibration["offsets"] = np.array(cfg.get("offsets", [0])) * np.pi
    return calibration
|
||||||
|
|
||||||
|
|
||||||
|
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
    """Convert real joint positions (encoder counts) into sim joint positions.

    Pipeline: subtract the starting position, align each axis with the sim
    convention, scale by 2*pi/4096 (counts to radians), then add the offsets.
    """
    displacement_counts = real_positions - start_pos
    aligned_counts = axis_directions * displacement_counts
    aligned_radians = aligned_counts * 2.0 * np.pi / 4096
    return aligned_radians + offsets
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################
|
||||||
|
# Control modes
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
    """Drive the sim environment from the real leader arm.

    Args:
        env: Zero-argument callable that builds the gym environment.
        robot: Robot whose main leader arm is read for "Present_Position" on every step.
        process_action_fn: Callable turning the leader position into a sim action.
        teleop_time_s: Optional duration in seconds. When None the loop never ends
            on its own.
    """
    sim_env = env()
    sim_env.reset()
    started_at = time.perf_counter()

    running = True
    while running:
        leader_pos = robot.leader_arms.main.read("Present_Position")
        sim_action = process_action_fn(leader_pos)
        # The action is stepped with an extra leading axis.
        sim_env.step(np.expand_dims(sim_action, 0))

        if teleop_time_s is None:
            continue
        if time.perf_counter() - started_at > teleop_time_s:
            print("Teleoperation processes finished.")
            running = False
|
||||||
|
|
||||||
|
|
||||||
|
def record(
    env,
    robot: Robot,
    process_action_from_leader,
    root: Path,
    repo_id: str,
    task: str,
    fps: int | None = None,
    tags: list[str] | None = None,
    pretrained_policy_name_or_path: str = None,
    policy_overrides: bool | None = None,
    episode_time_s: int = 30,
    num_episodes: int = 50,
    video: bool = True,
    push_to_hub: bool = True,
    num_image_writer_processes: int = 0,
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = False,
    play_sounds: bool = True,
    resume: bool = False,
    local_files_only: bool = False,
    run_compute_stats: bool = True,
    state_keys: dict | None = None,
) -> LeRobotDataset:
    """Record episodes in a sim environment, driven by a policy or by the leader arm.

    Args:
        env: Zero-argument callable that builds the gym environment.
        robot: Robot whose main leader arm is read when no policy is given.
        process_action_from_leader: Callable turning a leader position into a sim
            action. Required when no policy is given.
        root: Local root directory forwarded to `LeRobotDataset`.
        repo_id: Dataset identifier forwarded to `LeRobotDataset`.
        task: Task description saved with every episode.
        fps: Recording rate. When None, the policy's fps is used if a policy is given.
        tags: Tags forwarded to `push_to_hub`.
        pretrained_policy_name_or_path: Policy to load with `init_policy`, if any.
        policy_overrides: Overrides forwarded to `init_policy`.
        episode_time_s: Maximum wall-clock duration of an episode; None means unbounded.
        num_episodes: Number of episodes to save before stopping.
        video: Forwarded as `use_videos` when creating the dataset.
        push_to_hub: Whether to upload the dataset once recording is done.
        num_image_writer_processes: Image-writer subprocess count.
        num_image_writer_threads_per_camera: Image-writer threads per image observation.
        display_cameras: Show image observations with OpenCV when not headless.
        play_sounds: Forwarded to `log_say`.
        resume: Load the existing dataset and append to it instead of creating one.
        local_files_only: Forwarded to `LeRobotDataset` when resuming.
        run_compute_stats: Whether to consolidate the dataset with statistics.
        state_keys: Mapping from dataset feature name to observation key. When None,
            falls back to `env_cfg.state_keys`, a module global that only exists when
            this file is run as a script.

    Returns:
        The recorded dataset.

    Raises:
        ValueError: If neither a policy nor `process_action_from_leader` is provided,
            or if no fps can be determined.
    """
    # Load pretrained policy
    policy = None
    if pretrained_policy_name_or_path is not None:
        policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)

        if fps is None:
            fps = policy_fps
            logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")

    if policy is None and process_action_from_leader is None:
        raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")

    # The dataset and the per-frame fallback timestamp both divide by fps, so fail
    # early with a clear message instead of a TypeError in the middle of an episode.
    if fps is None:
        raise ValueError(
            "An fps is required to record: pass `fps` explicitly or provide a pretrained policy "
            "whose config defines one."
        )

    if state_keys is None:
        # Backward-compatible fallback to the config loaded by the script entry point.
        state_keys = env_cfg.state_keys

    # initialize listener before sim env
    listener, events = init_keyboard_listener()
    if events is None:
        # Every flag read below must exist, not only "exit_early".
        events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}

    if episode_time_s is None:
        episode_time_s = float("inf")

    # create sim env
    env = env()

    # Image observations are the ones whose key contains "image".
    image_keys = [key for key in env.observation_space if "image" in key]
    num_cameras = len(image_keys)

    # Single source of truth for the dataset feature name of each image key, so the
    # declared features and the recorded frames can never disagree.
    image_feature_names = {
        key: key if key.startswith("observation.image") else "observation.image." + key
        for key in image_keys
    }

    # Create empty dataset or load existing saved episodes
    if resume:
        dataset = LeRobotDataset(
            repo_id,
            root=root,
            local_files_only=local_files_only,
        )
        dataset.start_image_writer(
            num_processes=num_image_writer_processes,
            num_threads=num_image_writer_threads_per_camera * num_cameras,
        )
        sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
    else:
        # Copy so that the module-level DEFAULT_FEATURES is not mutated across calls.
        features = dict(DEFAULT_FEATURES)

        # add image keys to features
        for key in image_keys:
            features[image_feature_names[key]] = {
                "dtype": "video",
                "names": ["channel", "height", "width"],
                "shape": env.observation_space[key].shape,
            }

        for key, obs_key in state_keys.items():
            features[key] = {
                "dtype": "float32",
                "names": None,
                "shape": env.observation_space[obs_key].shape,
            }

        features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}

        sanity_check_dataset_name(repo_id, policy)
        dataset = LeRobotDataset.create(
            repo_id,
            fps,
            root=root,
            features=features,
            use_videos=video,
            image_writer_processes=num_image_writer_processes,
            image_writer_threads=num_image_writer_threads_per_camera * num_cameras,
        )

    recorded_episodes = 0
    while True:
        log_say(f"Recording episode {dataset.num_episodes}", play_sounds)

        timestamp = 0
        start_episode_t = time.perf_counter()

        # The seed is stored with every frame so the episode can be replayed from
        # the same initial state.
        seed = np.random.randint(0, 1e5)
        observation, info = env.reset(seed=seed)

        while timestamp < episode_time_s:
            start_loop_t = time.perf_counter()

            if policy is not None:
                action = predict_action(observation, policy, device, use_amp)
            else:
                leader_pos = robot.leader_arms.main.read("Present_Position")
                action = process_action_from_leader(leader_pos)

            observation, reward, terminated, _, info = env.step(action)

            success = info.get("is_success", False)
            env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps)

            frame = {
                # NOTE(review): assumes `action` is a numpy array on both control paths;
                # confirm what `predict_action` returns.
                "action": torch.from_numpy(action),
                "next.reward": reward,
                "next.success": success,
                "seed": seed,
                "timestamp": env_timestamp,
            }

            for key in image_keys:
                frame[image_feature_names[key]] = observation[key]

            for key, obs_key in state_keys.items():
                frame[key] = torch.from_numpy(observation[obs_key])

            dataset.add_frame(frame)

            if display_cameras and not is_headless():
                for key in image_keys:
                    cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR))
                cv2.waitKey(1)

            # Sleep for the remainder of the frame period to hold the target fps.
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

            dt_s = time.perf_counter() - start_loop_t
            log_control_info(robot, dt_s, fps=fps)

            timestamp = time.perf_counter() - start_episode_t
            if events["exit_early"] or terminated:
                events["exit_early"] = False
                break

        if events["rerecord_episode"]:
            log_say("Re-record episode", play_sounds)
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode(task=task)
        recorded_episodes += 1

        if events["stop_recording"] or recorded_episodes >= num_episodes:
            break
        else:
            logging.info("Waiting for a few seconds before starting next episode recording...")
            busy_wait(3)

    log_say("Stop recording", play_sounds, blocking=True)
    stop_recording(robot, listener, display_cameras)

    if run_compute_stats:
        logging.info("Computing dataset statistics")
    dataset.consolidate(run_compute_stats)

    if push_to_hub:
        dataset.push_to_hub(tags=tags)

    log_say("Exiting", play_sounds)
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def replay(
    env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
):
    """Replay the recorded actions of one episode in the sim environment.

    Args:
        env: Zero-argument callable that builds the gym environment.
        root: Local root directory of the dataset, or None to let `LeRobotDataset`
            pick its default location.
        repo_id: Dataset identifier.
        episode: Index of the episode to replay.
        fps: Replay rate. When None, actions are stepped without any pacing.
        local_files_only: Forwarded to `LeRobotDataset`.

    Raises:
        ValueError: If `root` is given and `{root}/{repo_id}` does not exist.
    """
    # `root` is None by default on the command line, and Path(None) raises a
    # TypeError, so the local directory is only checked when a root was given.
    if root is not None:
        local_dir = Path(root) / repo_id
        if not local_dir.exists():
            raise ValueError(local_dir)

    env = env()

    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
    items = dataset.hf_dataset.select_columns("action")
    seeds = dataset.hf_dataset.select_columns("seed")["seed"]

    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()

    # Reset with the seed stored alongside the episode's first frame.
    env.reset(seed=seeds[from_idx].item())

    logging.info("Replaying episode")
    log_say("Replaying episode", play_sounds=True)
    for idx in range(from_idx, to_idx):
        start_step_t = time.perf_counter()

        action = items[idx]["action"]
        env.step(action.unsqueeze(0).numpy())

        # fps=None disables pacing; dividing by None would raise a TypeError.
        if fps is not None:
            dt_s = time.perf_counter() - start_step_t
            busy_wait(1 / fps - dt_s)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: one sub-command per control mode
    # (teleoperate, record, replay), all sharing the robot and sim options.
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Set common options for all the subparsers
    base_parser = argparse.ArgumentParser(add_help=False)
    base_parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )

    base_parser.add_argument(
        "--sim-config",
        help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
    )

    # Teleoperation takes no option beyond the common ones.
    subparsers.add_parser("teleoperate", parents=[base_parser])

    parser_record = subparsers.add_parser("record", parents=[base_parser])
    parser_record.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_record.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
    )
    parser_record.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_record.add_argument(
        "--episode-time-s",
        type=int,
        default=60,
        help="Number of seconds for data recording for each episode.",
    )
    parser_record.add_argument(
        "--task",
        type=str,
        required=True,
        help="A description of the task preformed during recording that can be used as a language instruction.",
    )
    parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
    parser_record.add_argument(
        "--run-compute-stats",
        type=int,
        default=1,
        help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
    )
    parser_record.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload dataset to Hugging Face hub.",
    )
    parser_record.add_argument(
        "--tags",
        type=str,
        nargs="*",
        help="Add tags to your dataset on the hub.",
    )
    parser_record.add_argument(
        "--num-image-writer-processes",
        type=int,
        default=0,
        help=(
            "Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
            "set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
            "and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
            "If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
        ),
    )
    parser_record.add_argument(
        "--num-image-writer-threads-per-camera",
        type=int,
        default=4,
        help=(
            "Number of threads writing the frames as png images on disk, per camera. "
            "Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
            "Not enough threads might cause low camera fps."
        ),
    )
    parser_record.add_argument(
        "--display-cameras",
        type=int,
        default=0,
        help="Visualize image observations with opencv.",
    )
    parser_record.add_argument(
        "--resume",
        type=int,
        default=0,
        help="Resume recording on an existing dataset.",
    )

    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_replay.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory where the dataset will be stored locally (e.g. 'data/hf_username/dataset_name'). By default, stored in cache folder.",
    )
    parser_replay.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")

    args = parser.parse_args()

    init_logging()

    control_mode = args.mode
    robot_path = args.robot_path
    env_config_path = args.sim_config

    # Whatever remains in kwargs after removing the common options is forwarded
    # verbatim to the selected control function.
    kwargs = vars(args)
    del kwargs["mode"]
    del kwargs["robot_path"]
    del kwargs["sim_config"]

    # make gym env
    env_cfg = init_hydra_config(env_config_path)
    # Import the `gym_<name>` package named by the config before calling gym.make.
    importlib.import_module(f"gym_{env_cfg.env.name}")

    def env_constructor():
        return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)

    robot = None
    process_leader_actions_fn = None

    try:
        if control_mode in ["teleoperate", "record"]:
            # make robot
            robot_overrides = ["~cameras", "~follower_arms"]
            robot_cfg = init_hydra_config(robot_path, robot_overrides)
            robot = make_robot(robot_cfg)
            robot.connect()

            calib_kwgs = init_sim_calibration(robot, env_cfg.calibration)

            def process_leader_actions_fn(action):
                return real_positions_to_sim(action, **calib_kwgs)

            # NOTE(review): presumably cleared so that subsequent reads return raw
            # counts, which is what `real_positions_to_sim` expects; confirm.
            robot.leader_arms.main.calibration = None

        if control_mode == "teleoperate":
            teleoperate(env_constructor, robot, process_leader_actions_fn)

        elif control_mode == "record":
            record(env_constructor, robot, process_leader_actions_fn, **kwargs)

        elif control_mode == "replay":
            replay(env_constructor, **kwargs)

        else:
            raise ValueError(
                f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay."
            )
    finally:
        # Runs on errors and Ctrl+C too, not only on the success path.
        if robot and robot.is_connected:
            # Disconnect manually to avoid a "Core dump" during process
            # termination due to camera threads not properly exiting.
            robot.disconnect()
|
||||||
@@ -66,7 +66,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
|||||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "aloha_hdf5":
|
elif raw_format == "aloha_hdf5":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||||
elif "openx_rlds" in raw_format:
|
elif raw_format in ["rlds", "openx"]:
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
|
||||||
elif raw_format == "dora_parquet":
|
elif raw_format == "dora_parquet":
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||||
@@ -204,24 +204,14 @@ def push_dataset_to_hub(
|
|||||||
# convert dataset from original raw format to LeRobot format
|
# convert dataset from original raw format to LeRobot format
|
||||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||||
|
|
||||||
fmt_kwgs = {
|
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||||
"raw_dir": raw_dir,
|
raw_dir,
|
||||||
"videos_dir": videos_dir,
|
videos_dir,
|
||||||
"fps": fps,
|
fps,
|
||||||
"video": video,
|
video,
|
||||||
"episodes": episodes,
|
episodes,
|
||||||
"encoding": encoding,
|
encoding,
|
||||||
}
|
)
|
||||||
|
|
||||||
if "openx_rlds." in raw_format:
|
|
||||||
# Support for official OXE dataset name inside `raw_format`.
|
|
||||||
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
|
|
||||||
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
|
|
||||||
_, openx_dataset_name = raw_format.split(".")
|
|
||||||
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
|
|
||||||
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
|
|
||||||
|
|
||||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
|
|
||||||
|
|
||||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
@@ -290,7 +280,7 @@ def main():
|
|||||||
"--raw-format",
|
"--raw-format",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).",
|
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
|
|||||||
@@ -207,11 +207,17 @@ def main():
|
|||||||
required=True,
|
required=True,
|
||||||
help="Episode to visualize.",
|
help="Episode to visualize.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--local-files-only",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
default=None,
|
||||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir",
|
"--output-dir",
|
||||||
@@ -269,9 +275,16 @@ def main():
|
|||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
repo_id = kwargs.pop("repo_id")
|
repo_id = kwargs.pop("repo_id")
|
||||||
root = kwargs.pop("root")
|
root = kwargs.pop("root")
|
||||||
|
<<<<<<< HEAD
|
||||||
|
|
||||||
logging.info("Loading dataset")
|
logging.info("Loading dataset")
|
||||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
||||||
|
=======
|
||||||
|
local_files_only = kwargs.pop("local_files_only")
|
||||||
|
|
||||||
|
logging.info("Loading dataset")
|
||||||
|
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||||
|
>>>>>>> main
|
||||||
|
|
||||||
visualize_dataset(dataset, **vars(args))
|
visualize_dataset(dataset, **vars(args))
|
||||||
|
|
||||||
|
|||||||
@@ -234,6 +234,12 @@ def main():
|
|||||||
required=True,
|
required=True,
|
||||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--local-files-only",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--root",
|
"--root",
|
||||||
type=Path,
|
type=Path,
|
||||||
@@ -282,7 +288,13 @@ def main():
|
|||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
repo_id = kwargs.pop("repo_id")
|
repo_id = kwargs.pop("repo_id")
|
||||||
root = kwargs.pop("root")
|
root = kwargs.pop("root")
|
||||||
|
<<<<<<< HEAD
|
||||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
||||||
|
=======
|
||||||
|
local_files_only = kwargs.pop("local_files_only")
|
||||||
|
|
||||||
|
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||||
|
>>>>>>> main
|
||||||
visualize_dataset_html(dataset, **kwargs)
|
visualize_dataset_html(dataset, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||||||
assert dataset.meta.total_episodes == 2
|
assert dataset.meta.total_episodes == 2
|
||||||
assert len(dataset) == 2
|
assert len(dataset) == 2
|
||||||
|
|
||||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True)
|
||||||
|
|
||||||
# TODO(rcadene, aliberts): rethink this design
|
# TODO(rcadene, aliberts): rethink this design
|
||||||
if robot_type == "aloha":
|
if robot_type == "aloha":
|
||||||
@@ -295,24 +295,12 @@ def test_resume_record(tmpdir, request, robot_type, mock):
|
|||||||
dataset = record(**record_kwargs)
|
dataset = record(**record_kwargs)
|
||||||
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
|
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
|
||||||
|
|
||||||
# init_dataset_return_value = {}
|
|
||||||
|
|
||||||
# def wrapped_init_dataset(*args, **kwargs):
|
|
||||||
# nonlocal init_dataset_return_value
|
|
||||||
# init_dataset_return_value = init_dataset(*args, **kwargs)
|
|
||||||
# return init_dataset_return_value
|
|
||||||
|
|
||||||
# with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
|
|
||||||
|
|
||||||
with pytest.raises(FileExistsError):
|
with pytest.raises(FileExistsError):
|
||||||
# Dataset already exists, but resume=False by default
|
# Dataset already exists, but resume=False by default
|
||||||
record(**record_kwargs)
|
record(**record_kwargs)
|
||||||
|
|
||||||
dataset = record(**record_kwargs, resume=True)
|
dataset = record(**record_kwargs, resume=True)
|
||||||
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
|
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
|
||||||
# assert (
|
|
||||||
# init_dataset_return_value["num_episodes"] == 2
|
|
||||||
# ), "`init_dataset` should load the previous episode"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||||
|
|||||||
@@ -383,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
|||||||
include a report on what changed and how that affected the outputs.
|
include a report on what changed and how that affected the outputs.
|
||||||
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
||||||
add the policies you want to update the test artifacts for.
|
add the policies you want to update the test artifacts for.
|
||||||
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
|
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
|
||||||
should be updated.
|
should be updated.
|
||||||
4. Check that this test now passes.
|
4. Check that this test now passes.
|
||||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ we skip them for now in our CI.
|
|||||||
|
|
||||||
Example to run backward compatiblity tests locally:
|
Example to run backward compatiblity tests locally:
|
||||||
```
|
```
|
||||||
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
|
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -330,7 +330,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
|
"Not compatible with our CI since it downloads raw datasets. Run with `python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
|
||||||
)
|
)
|
||||||
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
|
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
|
||||||
_, dataset_id = repo_id.split("/")
|
_, dataset_id = repo_id.split("/")
|
||||||
|
|||||||
Reference in New Issue
Block a user