Compare commits

..

8 Commits

Author SHA1 Message Date
Remi Cadene
49ae3e19e1 Add clone, delete, WIP on remove_episode, drop_frame 2024-12-27 17:49:47 +01:00
Remi Cadene
ebe0bfad77 WIP 2024-12-11 09:09:14 -08:00
Remi Cadene
c6e9a3dc24 nit 2024-12-03 17:24:55 +01:00
Remi Cadene
afbd42d082 Add manage_dataset 2024-12-03 17:16:47 +01:00
Michel Aractingi
8e7d6970ea Control simulated robot with real leader (#514)
Co-authored-by: Remi <remi.cadene@huggingface.co>
2024-12-03 12:20:05 +01:00
Remi
286bca37cc Fix missing local_files_only in record/replay (#540)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
2024-12-03 10:53:21 +01:00
Michel Aractingi
a2c181992a Refactor OpenX (#505) 2024-12-03 00:51:55 +01:00
Simon Alibert
32eb0cec8f Dataset v2.0 (#461)
Co-authored-by: Remi <remi.cadene@huggingface.co>
2024-11-29 19:04:00 +01:00
31 changed files with 926 additions and 2271 deletions

View File

@@ -21,7 +21,7 @@ Provide a simple way for the reviewer to try out your changes.
Examples:
```bash
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
pytest -sx tests/test_stuff.py::test_something
```
```bash
python lerobot/scripts/train.py --some.option=true

View File

@@ -7,10 +7,8 @@ on:
schedule:
- cron: "0 2 * * *"
env:
DATA_DIR: tests/data
# env:
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
jobs:
run_all_tests_cpu:
name: CPU
@@ -30,13 +28,9 @@ jobs:
working-directory: /lerobot
steps:
- name: Tests
env:
DATA_DIR: tests/data
run: pytest -v --cov=./lerobot --disable-warnings tests
- name: Tests end-to-end
env:
DATA_DIR: tests/data
run: make test-end-to-end

View File

@@ -29,7 +29,6 @@ jobs:
name: Pytest
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
@@ -70,7 +69,6 @@ jobs:
name: Pytest (minimal install)
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
@@ -104,39 +102,38 @@ jobs:
&& rm -rf tests/outputs outputs
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
end-to-end:
name: End-to-end
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
with:
lfs: true # Ensure LFS files are pulled
# end-to-end:
# name: End-to-end
# runs-on: ubuntu-latest
# env:
# MUJOCO_GL: egl
# steps:
# - uses: actions/checkout@v4
# with:
# lfs: true # Ensure LFS files are pulled
- name: Install apt dependencies
# portaudio19-dev is needed to install pyaudio
run: |
sudo apt-get update && \
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
# - name: Install apt dependencies
# # portaudio19-dev is needed to install pyaudio
# run: |
# sudo apt-get update && \
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
- name: Install poetry
run: |
pipx install poetry && poetry config virtualenvs.in-project true
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
# - name: Install poetry
# run: |
# pipx install poetry && poetry config virtualenvs.in-project true
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "poetry"
# - name: Set up Python 3.10
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "poetry"
- name: Install poetry dependencies
run: |
poetry install --all-extras
# - name: Install poetry dependencies
# run: |
# poetry install --all-extras
- name: Test end-to-end
run: |
make test-end-to-end \
&& rm -rf outputs
# - name: Test end-to-end
# run: |
# make test-end-to-end \
# && rm -rf outputs

View File

@@ -267,7 +267,7 @@ We use `pytest` in order to run the tests. From the root of the
repository, here's how to run tests with `pytest` for the library:
```bash
DATA_DIR="tests/data" python -m pytest -sv ./tests
python -m pytest -sv ./tests
```

View File

@@ -153,10 +153,12 @@ python lerobot/scripts/visualize_dataset.py \
--episode-index 0
```
or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--local-files-only 1 \
--episode-index 0
```
@@ -208,12 +210,10 @@ dataset attributes:
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space or png files
- episode_data_index saved using `safetensor` tensor serialization format
- stats saved using `safetensor` tensor serialization format
- info are saved using JSON
- videos are stored in mp4 format to save space
- metadata are stored in plain json/jsonl files
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
### Evaluate a pretrained policy

View File

@@ -192,7 +192,6 @@ Record 2 episodes and upload your dataset to the hub:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/so100_test \
--tags so100 tutorial \
--warmup-time-s 5 \
@@ -212,7 +211,6 @@ echo ${HF_USER}/so100_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/so100_test
```
@@ -220,10 +218,9 @@ python lerobot/scripts/visualize_dataset_html.py \
Now try to replay the first episode on your robot:
```bash
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/so100_test \
--episode 0
```
@@ -232,7 +229,7 @@ DATA_DIR=data python lerobot/scripts/control_robot.py replay \
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/so100_test \
policy=act_so100_real \
env=so100_real \
@@ -248,7 +245,6 @@ Let's explain it:
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
@@ -259,7 +255,6 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_so100_test \
--tags so100 tutorial eval \
--warmup-time-s 5 \

View File

@@ -192,7 +192,6 @@ Record 2 episodes and upload your dataset to the hub:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/moss_test \
--tags moss tutorial \
--warmup-time-s 5 \
@@ -212,7 +211,6 @@ echo ${HF_USER}/moss_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/moss_test
```
@@ -220,10 +218,9 @@ python lerobot/scripts/visualize_dataset_html.py \
Now try to replay the first episode on your robot:
```bash
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/moss_test \
--episode 0
```
@@ -232,7 +229,7 @@ DATA_DIR=data python lerobot/scripts/control_robot.py replay \
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/moss_test \
policy=act_moss_real \
env=moss_real \
@@ -248,7 +245,6 @@ Let's explain it:
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
@@ -259,7 +255,6 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_moss_test \
--tags moss tutorial eval \
--warmup-time-s 5 \

View File

@@ -778,7 +778,6 @@ Now run this to record 2 episodes:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/koch_test \
--tags tutorial \
--warmup-time-s 5 \
@@ -787,7 +786,7 @@ python lerobot/scripts/control_robot.py record \
--num-episodes 2
```
This will write your dataset locally to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
@@ -840,7 +839,6 @@ In the coming months, we plan to release a foundational model for robotics. We a
You can visualize your dataset by running:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/koch_test
```
@@ -858,7 +856,6 @@ To replay the first episode of the dataset you just recorded, run the following
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/koch_test \
--episode 0
```
@@ -871,7 +868,7 @@ Your robot should replicate movements similar to those you recorded. For example
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/koch_test \
policy=act_koch_real \
env=koch_real \
@@ -918,7 +915,6 @@ env:
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
@@ -991,7 +987,6 @@ To this end, you can use the `record` function from [`lerobot/scripts/control_ro
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_koch_test \
--tags tutorial eval \
--warmup-time-s 5 \
@@ -1010,7 +1005,6 @@ As you can see, it's almost the same command as previously used to record your t
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
```bash
python lerobot/scripts/visualize_dataset.py \
--root data \
--repo-id ${HF_USER}/eval_koch_test
```

View File

@@ -128,7 +128,6 @@ Record one episode:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/stretch.yaml \
--fps 20 \
--root data \
--repo-id ${HF_USER}/stretch_test \
--tags stretch tutorial \
--warmup-time-s 3 \
@@ -146,7 +145,6 @@ Now try to replay this episode (make sure the robot's initial position is the sa
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/stretch.yaml \
--fps 20 \
--root data \
--repo-id ${HF_USER}/stretch_test \
--episode 0
```

View File

@@ -84,7 +84,6 @@ python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--tags aloha tutorial \
--warmup-time-s 5 \
@@ -104,7 +103,6 @@ echo ${HF_USER}/aloha_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/aloha_test
```
@@ -119,7 +117,6 @@ python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--episode 0
```
@@ -128,7 +125,7 @@ python lerobot/scripts/control_robot.py replay \
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/aloha_test \
policy=act_aloha_real \
env=aloha_real \
@@ -144,7 +141,6 @@ Let's explain it:
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
4. We provided `device=cuda` since we are training on a Nvidia GPU.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
@@ -156,7 +152,6 @@ python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_aloha_test \
--tags aloha tutorial eval \
--warmup-time-s 5 \

View File

@@ -63,7 +63,7 @@ def build_features(mode: str) -> dict:
return features
def load_raw_dataset(zarr_path: Path, load_images: bool = True):
def load_raw_dataset(zarr_path: Path):
try:
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,

View File

@@ -298,7 +298,7 @@ class LeRobotDatasetMetadata:
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
)
elif robot_type is None or features is None:
elif features is None:
raise ValueError(
"Dataset features must either come from a Robot or explicitly passed upon creation."
)
@@ -873,7 +873,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
def consolidate(
self,
run_compute_stats: bool = True,
keep_image_files: bool = False,
batch_size: int = 8,
num_workers: int = 8,
) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -896,7 +902,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if run_compute_stats:
self.stop_image_writer()
# TODO(aliberts): refactor stats in save_episodes
self.meta.stats = compute_stats(self)
self.meta.stats = compute_stats(self, batch_size=batch_size, num_workers=num_workers)
serialized_stats = serialize_dict(self.meta.stats)
write_json(serialized_stats, self.root / STATS_PATH)
self.consolidated = True
@@ -958,6 +964,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj
def clone(self, new_repo_id: str, new_root: str | Path | None = None) -> "LeRobotDataset":
return LeRobotDataset.create(
repo_id=new_repo_id,
fps=self.fps,
root=new_root,
robot=self.robot,
robot_type=self.robot_type,
features=self.features,
use_videos=self.use_videos,
tolerance_s=self.tolerance_s,
image_writer_processes=self.image_writer_processes,
image_writer_threads=self.image_writer_threads,
video_backend=self.video_backend,
)
def delete(self):
"""Delete the dataset locally. If it was push to hub, you can still access it by downloading it again."""
shutil.rmtree(self.root)
def remove_episode(self, episode: int | list[int]):
if isinstance(episode, int):
episode = [episode]
for ep in episode:
self.meta.info
def drop_frame(self, episode_range: dict[int, tuple[int]]):
pass
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.

View File

@@ -1,639 +0,0 @@
OPENX_DATASET_CONFIGS:
fractal20220817_data:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- base_pose_tool_reached
- gripper_closed
fps: 3
kuka:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- clip_function_input/base_pose_tool_reached
- gripper_closed
fps: 10
bridge_openx:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- EEF_state
- gripper_state
fps: 5
taco_play:
image_obs_keys:
- rgb_static
- rgb_gripper
depth_obs_keys:
- depth_static
- depth_gripper
state_obs_keys:
- state_eef
- state_gripper
fps: 15
jaco_play:
image_obs_keys:
- image
- image_wrist
depth_obs_keys:
- null
state_obs_keys:
- state_eef
- state_gripper
fps: 10
berkeley_cable_routing:
image_obs_keys:
- image
- top_image
- wrist45_image
- wrist225_image
depth_obs_keys:
- null
state_obs_keys:
- robot_state
fps: 10
roboturk:
image_obs_keys:
- front_rgb
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
nyu_door_opening_surprising_effectiveness:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 3
viola:
image_obs_keys:
- agentview_rgb
- eye_in_hand_rgb
depth_obs_keys:
- null
state_obs_keys:
- joint_states
- gripper_states
fps: 20
berkeley_autolab_ur5:
image_obs_keys:
- image
- hand_image
depth_obs_keys:
- image_with_depth
state_obs_keys:
- state
fps: 5
toto:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 30
language_table:
image_obs_keys:
- rgb
depth_obs_keys:
- null
state_obs_keys:
- effector_translation
fps: 10
columbia_cairlab_pusht_real:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- robot_state
fps: 10
stanford_kuka_multimodal_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- depth_image
state_obs_keys:
- ee_position
- ee_orientation
fps: 20
nyu_rot_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 3
io_ai_tech:
image_obs_keys:
- image
- image_fisheye
- image_left_side
- image_right_side
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 3
stanford_hydra_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
austin_buds_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
nyu_franka_play_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- image_additional_view
depth_obs_keys:
- depth
- depth_additional_view
state_obs_keys:
- eef_state
fps: 3
maniskill_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- depth
- wrist_depth
state_obs_keys:
- tcp_pose
- gripper_state
fps: 20
furniture_bench_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
cmu_franka_exploration_dataset_converted_externally_to_rlds:
image_obs_keys:
- highres_image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
ucsd_kitchen_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
fps: 2
ucsd_pick_and_place_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 3
spoc:
image_obs_keys:
- image
- image_manipulation
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 3
austin_sailor_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
austin_sirius_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
bc_z:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- present/xyz
- present/axis_angle
- present/sensed_close
fps: 10
utokyo_pr2_opening_fridge_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
utokyo_xarm_pick_and_place_converted_externally_to_rlds:
image_obs_keys:
- image
- image2
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- end_effector_pose
fps: 10
utokyo_xarm_bimanual_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- pose_r
fps: 10
robo_net:
image_obs_keys:
- image
- image1
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 1
robo_set:
image_obs_keys:
- image_left
- image_right
- image_wrist
depth_obs_keys:
- null
state_obs_keys:
- state
- state_velocity
fps: 5
berkeley_mvp_converted_externally_to_rlds:
image_obs_keys:
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- gripper
- pose
- joint_pos
fps: 5
berkeley_rpt_converted_externally_to_rlds:
image_obs_keys:
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- joint_pos
- gripper
fps: 30
kaist_nonprehensile_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
stanford_mask_vit_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
tokyo_u_lsmo_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
dlr_sara_pour_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
dlr_sara_grid_clamp_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
dlr_edan_shared_control_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 5
asu_table_top_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 12.5
stanford_robocook_converted_externally_to_rlds:
image_obs_keys:
- image_1
- image_2
depth_obs_keys:
- depth_1
- depth_2
state_obs_keys:
- eef_state
- gripper_state
fps: 5
imperialcollege_sawyer_wrist_cam:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
iamlab_cmu_pickup_insert_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
- gripper_state
fps: 20
uiuc_d3field:
image_obs_keys:
- image_1
- image_2
depth_obs_keys:
- depth_1
- depth_2
state_obs_keys:
- null
fps: 1
utaustin_mutex:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
berkeley_fanuc_manipulation:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
- gripper_state
fps: 10
cmu_playing_with_food:
image_obs_keys:
- image
- finger_vision_1
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
cmu_play_fusion:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 5
cmu_stretch:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
berkeley_gnm_recon:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 3
berkeley_gnm_cory_hall:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 5
berkeley_gnm_sac_son:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 10
droid:
image_obs_keys:
- exterior_image_1_left
- exterior_image_2_left
- wrist_image_left
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 15
droid_100:
image_obs_keys:
- exterior_image_1_left
- exterior_image_2_left
- wrist_image_left
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 15
fmb:
image_obs_keys:
- image_side_1
- image_side_2
- image_wrist_1
- image_wrist_2
depth_obs_keys:
- image_side_1_depth
- image_side_2_depth
- image_wrist_1_depth
- image_wrist_2_depth
state_obs_keys:
- proprio
fps: 10
dobbe:
image_obs_keys:
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 3.75
usc_cloth_sim_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
plex_robosuite:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
conq_hose_manipulation:
image_obs_keys:
- frontleft_fisheye_image
- frontright_fisheye_image
- hand_color_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 30

View File

@@ -1,106 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the Licens e.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py
data_utils.py
Additional utils for data processing.
"""
from typing import Any, Dict, List
import tensorflow as tf
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts gripper actions from continuous to binary values (0 and 1).
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
values based on the state that is reached _after_ those intermediate values.
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
chunk of intermediate values as the last action in the trajectory.
The `scan_fn` implements the following logic:
new_actions = np.empty_like(actions)
carry = actions[-1]
for i in reversed(range(actions.shape[0])):
if in_between_mask[i]:
carry = carry
else:
carry = float(open_mask[i])
new_actions[i] = carry
"""
open_mask, closed_mask = actions > 0.95, actions < 0.05
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
is_open_float = tf.cast(open_mask, tf.float32)
def scan_fn(carry, i):
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
return 1 - actions
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
"""
# Note =>> -1 for closing, 1 for opening, 0 for no change
opening_mask, closing_mask = actions < -0.1, actions > 0.1
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
def scan_fn(carry, i):
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
# If no relative grasp, assumes open for whole trajectory
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
start = tf.cond(start == 0, lambda: 1, lambda: start)
# Note =>> -1 for closed, 1 for open
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
return new_actions
# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
return traj_truncated
# === RLDS Dataset Initialization Utilities ===
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
print("\n######################################################################################")
print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
pad = 80 - len(dataset_kwargs["name"])
print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
print("######################################################################################\n")

View File

@@ -1,200 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
OpenVLA: https://github.com/openvla/openvla
Episode transforms for DROID dataset.
"""
from typing import Any, Dict
import tensorflow as tf
import tensorflow_graphics.geometry.transformation as tfg
def rmat_to_euler(rot_mat):
return tfg.euler.from_rotation_matrix(rot_mat)
def euler_to_rmat(euler):
return tfg.rotation_matrix_3d.from_euler(euler)
def invert_rmat(rot_mat):
return tfg.rotation_matrix_3d.inverse(rot_mat)
def rotmat_to_rot6d(mat):
"""
Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
Args:
mat: rotation matrix
Returns: 6d vector (first two rows of rotation matrix)
"""
r6 = mat[..., :2, :]
r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
r6_flat = tf.concat([r6_0, r6_1], axis=-1)
return r6_flat
def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
"""
Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
Args:
velocity: 6d velocity action (3 x translation, 3 x rotation)
wrist_in_robot_frame: 6d pose of the end-effector in robot base frame
Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
"""
r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
r_frame_inv = invert_rmat(r_frame)
# world to wrist: dT_pi = R^-1 dT_rbt
vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0]
# world to wrist: dR_pi = R^-1 dR_rbt R
dr_ = euler_to_rmat(velocity[:, 3:6])
dr_ = r_frame_inv @ (dr_ @ r_frame)
dr_r6 = rotmat_to_rot6d(dr_)
return tf.concat([vel_t, dr_r6], axis=-1)
def rand_swap_exterior_images(img1, img2):
"""
Randomly swaps the two exterior images (for training with single exterior input).
"""
return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dr_,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *wrist* frame of the robot.
"""
wrist_act = velocity_act_to_wrist_frame(
trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
)
trajectory["action"] = tf.concat(
(
wrist_act,
trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dr_,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def zero_action_filter(traj: Dict) -> bool:
"""
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
"""
droid_q01 = tf.convert_to_tensor(
[
-0.7776297926902771,
-0.5803514122962952,
-0.5795090794563293,
-0.6464047729969025,
-0.7041108310222626,
-0.8895104378461838,
]
)
droid_q99 = tf.convert_to_tensor(
[
0.7597932070493698,
0.5726242214441299,
0.7351000607013702,
0.6705610305070877,
0.6464948207139969,
0.8897542208433151,
]
)
droid_norm_0_act = (
2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1
)
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5)

View File

@@ -1,859 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
OpenVLA: https://github.com/openvla/openvla
Octo: https://github.com/octo-models/octo
transforms.py
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
Transforms adopt the following structure:
Input: Dictionary of *batched* features (i.e., has leading time dimension)
Output: Dictionary `step` =>> {
"observation": {
<image_keys, depth_image_keys>
State (in chosen state representation)
},
"action": Action (in chosen action representation),
"language_instruction": str
}
"""
from typing import Any, Dict
import tensorflow as tf
from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import (
binarize_gripper_actions,
invert_gripper_actions,
rel2abs_gripper_actions,
relabel_bridge_actions,
)
def droid_baseact_transform_fn():
from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform
return droid_baseact_transform
def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to version of Bridge V2 in Open X-Embodiment mixture.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key in ["observation", "action"]:
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to original version of Bridge V2 from the official project website.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key == "observation":
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
return trajectory
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
# decode compressed state
eef_value = tf.io.decode_compressed(
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
compression_type="ZLIB",
)
eef_value = tf.io.decode_raw(eef_value, tf.float32)
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
gripper_value = tf.io.decode_compressed(
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
)
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
trajectory["action"] = trajectory["action"]["rel_actions_world"]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
:, -1:
]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
tf.zeros_like(trajectory["action"]["world_vector"]),
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert absolute gripper action, +1 = open, 0 = close
gripper_action = invert_gripper_actions(
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory["language_embedding"] = trajectory["observation"]["natural_language_embedding"]
return trajectory
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
gripper_action = invert_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# default to "open" gripper
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.ones_like(trajectory["action"][:, :1]),
),
axis=-1,
)
# decode language instruction
instruction_bytes = trajectory["observation"]["instruction"]
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
:, 0
]
return trajectory
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
trajectory["action"]["gripper_closedness_action"][:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
trajectory["action"] = trajectory["action"][..., :7]
return trajectory
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
trajectory["observation"]["state"][:, 7:10],
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
return trajectory
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
return trajectory
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
trajectory["observation"]["depth_additional_view"] = tf.cast(
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
)
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
# clip gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, -8:-2],
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
),
axis=-1,
)
return trajectory
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
return trajectory
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["state"][:, :7],
trajectory["observation"]["state"][:, -1:],
),
axis=-1,
)
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["future/xyz_residual"][:, :3],
trajectory["action"]["future/axis_angle_residual"][:, :3],
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., -7:]
return trajectory
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :4],
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
trajectory["observation"]["state"] = tf.concat((
tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32),
trajectory["observation"]["pose"],
trajectory["observation"]["joint_pos"],),
axis=-1,)
"""
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
return trajectory
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
return trajectory
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["end_effector_pose"][:, :4],
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
return trajectory
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
return trajectory
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, 7:8],
),
axis=-1,
)
return trajectory
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"],
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
),
axis=-1,
)
return trajectory
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
trajectory["action"][:, -4:],
),
axis=-1,
)
return trajectory
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["position"],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
trajectory["observation"]["yaw"],
),
axis=-1,
)
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["eef_pose"],
trajectory["observation"]["state_gripper_pose"][..., None],
),
axis=-1,
)
return trajectory
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
return trajectory
def robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# gripper action is in -1...1 --> clip to 0...1, flip
gripper_action = trajectory["action"][:, -1:]
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :7],
gripper_action,
),
axis=-1,
)
return trajectory
def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
# === Registry ===
OPENX_STANDARDIZATION_TRANSFORMS = {
"bridge_openx": bridge_openx_dataset_transform,
"bridge_orig": bridge_orig_dataset_transform,
"bridge_dataset": bridge_orig_dataset_transform,
"ppgm": ppgm_dataset_transform,
"ppgm_static": ppgm_dataset_transform,
"ppgm_wrist": ppgm_dataset_transform,
"fractal20220817_data": rt1_dataset_transform,
"kuka": kuka_dataset_transform,
"taco_play": taco_play_dataset_transform,
"jaco_play": jaco_play_dataset_transform,
"berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
"roboturk": roboturk_dataset_transform,
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
"viola": viola_dataset_transform,
"berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
"toto": toto_dataset_transform,
"language_table": language_table_dataset_transform,
"columbia_cairlab_pusht_real": pusht_dataset_transform,
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
"nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
"stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
"austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
"nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
"maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
"furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
"cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
"ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
"austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
"austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
"bc_z": bc_z_dataset_transform,
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform,
"utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
"robo_net": robo_net_dataset_transform,
"berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
"berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
"kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
"stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
"tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
"dlr_sara_pour_converted_externally_to_rlds": identity_transform,
"dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
"dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
"asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
"stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
"imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
"uiuc_d3field": uiuc_d3field_dataset_transform,
"utaustin_mutex": utaustin_mutex_dataset_transform,
"berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
"cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
"cmu_play_fusion": playfusion_dataset_transform,
"cmu_stretch": cmu_stretch_dataset_transform,
"berkeley_gnm_recon": gnm_dataset_transform,
"berkeley_gnm_cory_hall": gnm_dataset_transform,
"berkeley_gnm_sac_son": gnm_dataset_transform,
"droid": droid_baseact_transform_fn(),
"droid_100": droid_baseact_transform_fn(), # first 100 episodes of droid
"fmb": fmb_transform,
"dobbe": dobbe_dataset_transform,
"robo_set": robo_set_dataset_transform,
"usc_cloth_sim_converted_externally_to_rlds": identity_transform,
"plex_robosuite": identity_transform,
"conq_hose_manipulation": identity_transform,
"io_ai_tech": identity_transform,
"spoc": identity_transform,
}

View File

@@ -14,13 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
Example:
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
--repo-id youliangtan/sampled_bridge_data_v2 \
--raw-format openx_rlds.bridge_orig \
--raw-dir /path/to/data/bridge_dataset/1.0.0/ \
--repo-id your_hub/sampled_bridge_data_v2 \
--raw-format rlds \
--episodes 3 4 5 8 9
Exact dataset fps defined in openx/config.py, obtained from:
@@ -35,12 +38,10 @@ import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
import yaml
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
@@ -52,11 +53,6 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
_openx_list = yaml.safe_load(f)
OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
np.set_printoptions(precision=2)
@@ -108,7 +104,6 @@ def load_from_raw(
video: bool,
episodes: list[int] | None = None,
encoding: dict | None = None,
openx_dataset_name: str | None = None,
):
"""
Args:
@@ -136,16 +131,17 @@ def load_from_raw(
# we will apply the standardization transform if the dataset_name is provided
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
# search for 'image' keys in the observations
if openx_dataset_name is not None:
print(" - applying standardization transform for dataset: ", openx_dataset_name)
assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
dataset = dataset.map(transform_fn)
image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
else:
obs_keys = dataset_info.features["steps"]["observation"].keys()
image_keys = [key for key in obs_keys if "image" in key]
image_keys = []
state_keys = []
observation_info = dataset_info.features["steps"]["observation"]
for key in observation_info:
# check whether the key is for an image or a vector observation
if len(observation_info[key].shape) == 3:
# only adding uint8 images discards depth images
if observation_info[key].dtype == tf.uint8:
image_keys.append(key)
else:
state_keys.append(key)
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
@@ -193,50 +189,31 @@ def load_from_raw(
num_frames = episode["action"].shape[0]
###########################################################
# Handle the episodic data
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
ep_dict = {}
langs = [] # TODO: might be located in "observation"
for key in state_keys:
ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
image_array_dict = {key: [] for key in image_keys}
# We will create the state observation tensor by stacking the state
# obs keys defined in the openx/configs.py
if openx_dataset_name is not None:
state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
# stack the state observations, if is None, pad with zeros
states = []
for key in state_obs_keys:
if key in episode["observation"]:
states.append(tf_to_torch(episode["observation"][key]))
else:
states.append(torch.zeros(num_frames, 1)) # pad with zeros
states = torch.cat(states, dim=1)
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
else:
states = tf_to_torch(episode["observation"]["state"])
actions = tf_to_torch(episode["action"])
rewards = tf_to_torch(episode["reward"]).float()
ep_dict["action"] = tf_to_torch(episode["action"])
ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
ep_dict["next.done"] = tf_to_torch(episode["is_last"])
ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
ep_dict["is_first"] = tf_to_torch(episode["is_first"])
ep_dict["discount"] = tf_to_torch(episode["discount"])
# If lang_key is present, convert the entire tensor at once
if lang_key is not None:
langs = [str(x) for x in episode[lang_key]]
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
image_array_dict = {key: [] for key in image_keys}
for im_key in image_keys:
imgs = episode["observation"][im_key]
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
# simple assertions
for item in [states, actions, rewards, done]:
assert len(item) == num_frames
###########################################################
# loop through all cameras
for im_key in image_keys:
img_key = f"observation.images.{im_key}"
@@ -262,17 +239,6 @@ def load_from_raw(
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
if lang_key is not None:
ep_dict["language_instruction"] = langs
ep_dict["observation.state"] = states
ep_dict["action"] = actions
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["next.reward"] = rewards
ep_dict["next.done"] = done
path_ep_dict = tmp_ep_dicts_dir.joinpath(
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
)
@@ -290,30 +256,28 @@ def load_from_raw(
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
for key in data_dict:
# check if vector state obs
if key.startswith("observation.") and "observation.images." not in key:
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
# check if image obs
elif "observation.images." in key:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
if "language_instruction" in data_dict:
features["language_instruction"] = Value(dtype="string", id=None)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["is_terminal"] = Value(dtype="bool", id=None)
features["is_first"] = Value(dtype="bool", id=None)
features["discount"] = Value(dtype="float32", id=None)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
@@ -333,19 +297,8 @@ def from_raw_to_lerobot_format(
video: bool = True,
episodes: list[int] | None = None,
encoding: dict | None = None,
openx_dataset_name: str | None = None,
):
"""This is a test impl for rlds conversion"""
if openx_dataset_name is None:
# set a default rlds frame rate if the dataset is not from openx
fps = 30
elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
raise ValueError(
"fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
)
fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {

View File

@@ -1,23 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@dataclass
class HILSerlConfig:
pass

View File

@@ -1,30 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
class HILSerlPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"],
):
pass

View File

@@ -1,23 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@dataclass
class SACConfig:
discount = 0.99

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from huggingface_hub import PyTorchModelHubMixin
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
super().__init__()
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.critic_ensemble = ...
self.critic_target = ...
self.actor_network = ...
self.temperature = ...
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
queues are populated during rollout of the policy, they contain the n latest observations and actions
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=1),
}
if self._use_image:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
self._queues["observation.environment_state"] = deque(maxlen=1)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actions, _ = self.actor_network(batch['observations'])###
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
observation_batch =
next_obaservation_batch =
action_batch =
reward_batch =
dones_batch =
# perform image augmentation
# reward bias
# from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
next_actions = ..
# 2- compute q targets
q_targets = self.target_qs(next_obaservation_batch, next_actions)
# critics subsample size
min_q = q_targets.min(dim=0)
# backup entropy
td_target = reward_batch + self.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_ensemble(observation_batch, action_batch)
# 4- Calculate loss
critics_loss = F.mse_loss(q_preds,
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])) # dones masks
# calculate actors loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor_network(observation_batch)
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observation_batch, actions, return_type="mean")
actor_loss = -(q_preds - temperature * log_probs).mean()
# calculate temperature loss
# 1- calculate entropy
entropy = -log_probs.mean()
temperature_loss = temperature * (entropy - self.target_entropy).mean()
loss = critics_loss + actor_loss + temperature_loss
return {
"Q_value_loss": critics_loss.item(),
"pi_loss": actor_loss.item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"entropy": entropy.item(),
"loss": loss,
}
def update(self):
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)

View File

@@ -353,7 +353,7 @@ def sanity_check_dataset_robot_compatibility(
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value)
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")

View File

@@ -29,7 +29,6 @@ python lerobot/scripts/control_robot.py teleoperate \
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--num-episodes 1 \
--run-compute-stats 0
@@ -38,7 +37,6 @@ python lerobot/scripts/control_robot.py record \
- Visualize dataset:
```bash
python lerobot/scripts/visualize_dataset.py \
--root tmp/data \
--repo-id $USER/koch_test \
--episode-index 0
```
@@ -47,7 +45,6 @@ python lerobot/scripts/visualize_dataset.py \
```bash
python lerobot/scripts/control_robot.py replay \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--episode 0
```
@@ -57,7 +54,6 @@ python lerobot/scripts/control_robot.py replay \
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root data \
--repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \
--warmup-time-s 2 \
@@ -72,12 +68,12 @@ python lerobot/scripts/control_robot.py record \
- Tap escape key 'esc' to stop the data recording.
This might require a sudo permission to allow your terminal to monitor keyboard events.
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
To avoid resuming by deleting the dataset, use `--force-override 1`.
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--resume 1`.
If the dataset you want to extend is not on the hub, you also need to add `--local-files-only 1`.
- Train on this dataset with the ACT policy:
```bash
DATA_DIR=data python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
policy=act_koch_real \
env=koch_real \
dataset_repo_id=$USER/koch_pick_place_lego \
@@ -88,7 +84,6 @@ DATA_DIR=data python lerobot/scripts/train.py \
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--root data \
--repo-id $USER/eval_act_koch_real \
--num-episodes 10 \
--warmup-time-s 2 \
@@ -191,7 +186,7 @@ def teleoperate(
@safe_disconnect
def record(
robot: Robot,
root: str,
root: Path,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
@@ -204,6 +199,7 @@ def record(
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
tags: list[str] | None = None,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
@@ -331,7 +327,7 @@ def record(
dataset.consolidate(run_compute_stats)
if push_to_hub:
dataset.push_to_hub()
dataset.push_to_hub(tags=tags)
log_say("Exiting", play_sounds)
return dataset
@@ -345,7 +341,7 @@ def replay(
episode: int,
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = True,
local_files_only: bool = False,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
@@ -427,8 +423,8 @@ if __name__ == "__main__":
parser_record.add_argument(
"--root",
type=Path,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
default=None,
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
)
parser_record.add_argument(
"--repo-id",
@@ -436,6 +432,12 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser_record.add_argument(
"--warmup-time-s",
type=int,
@@ -495,10 +497,10 @@ if __name__ == "__main__":
),
)
parser_record.add_argument(
"--force-override",
"--resume",
type=int,
default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"-p",
@@ -523,8 +525,8 @@ if __name__ == "__main__":
parser_replay.add_argument(
"--root",
type=Path,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
default=None,
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
)
parser_replay.add_argument(
"--repo-id",
@@ -532,6 +534,12 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
args = parser.parse_args()

View File

@@ -0,0 +1,546 @@
"""
Utilities to control a robot in simulation.
Useful to record a dataset, replay a recorded episode and record an evaluation dataset.
Examples of usage:
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency.
You can modify this value depending on how fast your simulation can run:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--fps 30 \
--robot-path lerobot/configs/robot/your_robot_config.yaml \
--sim-config lerobot/configs/env/your_sim_config.yaml
```
- Record one episode in order to test replay:
```bash
python lerobot/scripts/control_sim_robot.py record \
--robot-path lerobot/configs/robot/your_robot_config.yaml \
--sim-config lerobot/configs/env/your_sim_config.yaml \
--fps 30 \
--repo-id $USER/robot_sim_test \
--num-episodes 1 \
--run-compute-stats 0
```
Enable the --push-to-hub 1 to push the recorded dataset to the huggingface hub.
- Visualize dataset:
```bash
python lerobot/scripts/visualize_dataset.py \
--repo-id $USER/robot_sim_test \
--episode-index 0
```
- Replay a sequence of test episodes:
```bash
python lerobot/scripts/control_sim_robot.py replay \
--robot-path lerobot/configs/robot/your_robot_config.yaml \
--sim-config lerobot/configs/env/your_sim_config.yaml \
--fps 30 \
--repo-id $USER/robot_sim_test \
--episode 0
```
Note: The seed is saved, therefore, during replay we can load the same environment state as the one during collection.
- Record a full dataset in order to train a policy,
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
```bash
python lerobot/scripts/control_sim_robot.py record \
--robot-path lerobot/configs/robot/your_robot_config.yaml \
--sim-config lerobot/configs/env/your_sim_config.yaml \
--fps 30 \
--repo-id $USER/robot_sim_test \
--num-episodes 50 \
--episode-time-s 30 \
```
**NOTE**: You can use your keyboard to control data recording flow.
- Tap right arrow key '->' to early exit while recording an episode and go to reseting the environment.
- Tap right arrow key '->' to early exit while reseting the environment and got to recording the next episode.
- Tap left arrow key '<-' to early exit and re-record the current episode.
- Tap escape key 'esc' to stop the data recording.
This might require a sudo permission to allow your terminal to monitor keyboard events.
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
"""
import argparse
import importlib
import logging
import time
from pathlib import Path
import cv2
import gymnasium as gym
import numpy as np
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robot_devices.control_utils import (
init_keyboard_listener,
init_policy,
is_headless,
log_control_info,
predict_action,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
stop_recording,
)
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
DEFAULT_FEATURES = {
"next.reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"next.success": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"seed": {
"dtype": "int64",
"shape": (1,),
"names": None,
},
"timestamp": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
}
########################################################################################
# Utilities
########################################################################################
def none_or_int(value):
if value == "None":
return None
return int(value)
def init_sim_calibration(robot, cfg):
# Constants necessary for transforming the joint pos of the real robot to the sim
# depending on the robot discription used in that sim.
start_pos = np.array(robot.leader_arms.main.calibration["start_pos"])
axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
"""Counts - starting position -> radians -> align axes -> offset"""
return axis_directions * (real_positions - start_pos) * 2.0 * np.pi / 4096 + offsets
########################################################################################
# Control modes
########################################################################################
def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
env = env()
env.reset()
start_teleop_t = time.perf_counter()
while True:
leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0))
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
print("Teleoperation processes finished.")
break
def record(
env,
robot: Robot,
process_action_from_leader,
root: Path,
repo_id: str,
task: str,
fps: int | None = None,
tags: list[str] | None = None,
pretrained_policy_name_or_path: str = None,
policy_overrides: bool | None = None,
episode_time_s: int = 30,
num_episodes: int = 50,
video: bool = True,
push_to_hub: bool = True,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = False,
play_sounds: bool = True,
resume: bool = False,
local_files_only: bool = False,
run_compute_stats: bool = True,
) -> LeRobotDataset:
# Load pretrained policy
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
if policy is None and process_action_from_leader is None:
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
# initialize listener before sim env
listener, events = init_keyboard_listener()
# create sim env
env = env()
# Create empty dataset or load existing saved episodes
num_cameras = sum([1 if "image" in key else 0 for key in env.observation_space])
# get image keys
image_keys = [key for key in env.observation_space if "image" in key]
state_keys_dict = env_cfg.state_keys
if resume:
dataset = LeRobotDataset(
repo_id,
root=root,
local_files_only=local_files_only,
)
dataset.start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * num_cameras,
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
else:
features = DEFAULT_FEATURES
# add image keys to features
for key in image_keys:
shape = env.observation_space[key].shape
if not key.startswith("observation.image."):
key = "observation.image." + key
features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
for key, obs_key in state_keys_dict.items():
features[key] = {
"dtype": "float32",
"names": None,
"shape": env.observation_space[obs_key].shape,
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
features=features,
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * num_cameras,
)
recorded_episodes = 0
while True:
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
if events is None:
events = {"exit_early": False}
if episode_time_s is None:
episode_time_s = float("inf")
timestamp = 0
start_episode_t = time.perf_counter()
seed = np.random.randint(0, 1e5)
observation, info = env.reset(seed=seed)
while timestamp < episode_time_s:
start_loop_t = time.perf_counter()
if policy is not None:
action = predict_action(observation, policy, device, use_amp)
else:
leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_from_leader(leader_pos)
observation, reward, terminated, _, info = env.step(action)
success = info.get("is_success", False)
env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps)
frame = {
"action": torch.from_numpy(action),
"next.reward": reward,
"next.success": success,
"seed": seed,
"timestamp": env_timestamp,
}
for key in image_keys:
if not key.startswith("observation.image"):
frame["observation.image." + key] = observation[key]
else:
frame[key] = observation[key]
for key, obs_key in state_keys_dict.items():
frame[key] = torch.from_numpy(observation[obs_key])
dataset.add_frame(frame)
if display_cameras and not is_headless():
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_episode_t
if events["exit_early"] or terminated:
events["exit_early"] = False
break
if events["rerecord_episode"]:
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode(task=task)
recorded_episodes += 1
if events["stop_recording"] or recorded_episodes >= num_episodes:
break
else:
logging.info("Waiting for a few seconds before starting next episode recording...")
busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
if run_compute_stats:
logging.info("Computing dataset statistics")
dataset.consolidate(run_compute_stats)
if push_to_hub:
dataset.push_to_hub(tags=tags)
log_say("Exiting", play_sounds)
return dataset
def replay(
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
):
env = env()
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
items = dataset.hf_dataset.select_columns("action")
seeds = dataset.hf_dataset.select_columns("seed")["seed"]
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
env.reset(seed=seeds[from_idx].item())
logging.info("Replaying episode")
log_say("Replaying episode", play_sounds=True)
for idx in range(from_idx, to_idx):
start_episode_t = time.perf_counter()
action = items[idx]["action"]
env.step(action.unsqueeze(0).numpy())
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / fps - dt_s)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True)
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
base_parser.add_argument(
"--sim-config",
help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
)
parser_record = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_record.add_argument(
"--root",
type=Path,
default=None,
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_record.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--episode-time-s",
type=int,
default=60,
help="Number of seconds for data recording for each episode.",
)
parser_record.add_argument(
"--task",
type=str,
required=True,
help="A description of the task preformed during recording that can be used as a language instruction.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--run-compute-stats",
type=int,
default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
)
parser_record.add_argument(
"--push-to-hub",
type=int,
default=1,
help="Upload dataset to Hugging Face hub.",
)
parser_record.add_argument(
"--tags",
type=str,
nargs="*",
help="Add tags to your dataset on the hub.",
)
parser_record.add_argument(
"--num-image-writer-processes",
type=int,
default=0,
help=(
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
),
)
parser_record.add_argument(
"--num-image-writer-threads-per-camera",
type=int,
default=4,
help=(
"Number of threads writing the frames as png images on disk, per camera. "
"Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
"Not enough threads might cause low camera fps."
),
)
parser_record.add_argument(
"--display-cameras",
type=int,
default=0,
help="Visualize image observations with opencv.",
)
parser_record.add_argument(
"--resume",
type=int,
default=0,
help="Resume recording on an existing dataset.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_replay.add_argument(
"--root",
type=Path,
default=None,
help="Root directory where the dataset will be stored locally (e.g. 'data/hf_username/dataset_name'). By default, stored in cache folder.",
)
parser_replay.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
args = parser.parse_args()
init_logging()
control_mode = args.mode
robot_path = args.robot_path
env_config_path = args.sim_config
kwargs = vars(args)
del kwargs["mode"]
del kwargs["robot_path"]
del kwargs["sim_config"]
# make gym env
env_cfg = init_hydra_config(env_config_path)
importlib.import_module(f"gym_{env_cfg.env.name}")
def env_constructor():
return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)
robot = None
process_leader_actions_fn = None
if control_mode in ["teleoperate", "record"]:
# make robot
robot_overrides = ["~cameras", "~follower_arms"]
robot_cfg = init_hydra_config(robot_path, robot_overrides)
robot = make_robot(robot_cfg)
robot.connect()
calib_kwgs = init_sim_calibration(robot, env_cfg.calibration)
def process_leader_actions_fn(action):
return real_positions_to_sim(action, **calib_kwgs)
robot.leader_arms.main.calibration = None
if control_mode == "teleoperate":
teleoperate(env_constructor, robot, process_leader_actions_fn)
elif control_mode == "record":
record(env_constructor, robot, process_leader_actions_fn, **kwargs)
elif control_mode == "replay":
replay(env_constructor, **kwargs)
else:
raise ValueError(
f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay."
)
if robot and robot.is_connected:
# Disconnect manually to avoid a "Core dump" during process
# termination due to camera threads not properly exiting.
robot.disconnect()

View File

@@ -0,0 +1,188 @@
"""
Utilities to manage a dataset.
Examples of usage:
- Consolidate a dataset, by encoding images into videos and computing statistics:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
--repo-id $USER/koch_test
```
- Consolidate a dataset which is not uploaded on the hub yet:
```bash
python lerobot/scripts/manage_dataset.py consolidate \
--repo-id $USER/koch_test \
--local-files-only 1
```
- Upload a dataset on the hub:
```bash
python lerobot/scripts/manage_dataset.py push_to_hub \
--repo-id $USER/koch_test
```
"""
import argparse
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def parse_episode_range_string(ep_range_str):
parts = ep_range_str.split("-")
if len(parts) != 3:
raise ValueError(
f"Invalid episode range string '{ep_range_str}'. Expected format: 'EP-FROM-TO', e.g., '1-5-10'."
)
ep, start, end = parts
return int(ep), int(start), int(end)
def parse_episode_range_strings(ep_range_strings):
ep_ranges = {}
for ep_range_str in ep_range_strings:
ep, start, end = parse_episode_range_string(ep_range_str)
if ep not in ep_ranges:
ep_ranges[ep] = []
ep_ranges[ep].append((start, end))
return ep_ranges
if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True)
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory where the dataset is stored (e.g. 'dataset/path').",
)
base_parser.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
base_parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
############################################################################
# consolidate
parser_conso = subparsers.add_parser("consolidate", parents=[base_parser])
parser_conso.add_argument(
"--batch-size",
type=int,
default=8,
help="Batch size loaded by DataLoader for computing the dataset statistics.",
)
parser_conso.add_argument(
"--num-workers",
type=int,
default=8,
help="Number of processes of Dataloader for computing the dataset statistics.",
)
############################################################################
# push_to_hub
parser_push = subparsers.add_parser("push_to_hub", parents=[base_parser])
parser_push.add_argument(
"--tags",
type=str,
nargs="*",
default=None,
help="Optional additional tags to categorize the dataset on the Hugging Face Hub. Use space-separated values (e.g. 'so100 indoor'). The tag 'LeRobot' will always be added.",
)
parser_push.add_argument(
"--license",
type=str,
default="apache-2.0",
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
)
parser_push.add_argument(
"--private",
type=int,
default=0,
help="Create a private dataset repository on the Hugging Face Hub. Push publicly by default.",
)
############################################################################
# clone
parser_clone = subparsers.add_parser("clone", parents=[base_parser])
parser_clone.add_argument(
"--root",
type=Path,
default=None,
help="New root directory where the dataset is stored (e.g. 'dataset/path').",
)
parser_clone.add_argument(
"--new-repo-id",
type=str,
help="New dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
############################################################################
# delete
parser_del = subparsers.add_parser("delete", parents=[base_parser])
############################################################################
# remove_episode
parser_rm_ep = subparsers.add_parser("remove_episode", parents=[base_parser])
parser_rm_ep.add_argument(
"--episode",
type=int,
nargs="*",
help="List of one or several episodes to be removed from the dataset locally.",
)
############################################################################
# drop_frame
parser_drop_frame = subparsers.add_parser("drop_frame", parents=[base_parser])
parser_rm_ep.add_argument(
"--episode-range",
type=str,
nargs="*",
help="List of one or several frame ranges per episode to be removed from the dataset locally. For instance, using `--episode-frame-range 0-0-10 3-5-20` will remove from episode 0, the frames from indices 0 to 10 excluded, and from episode 3 the frames from indices 5 to 20.",
)
args = parser.parse_args()
kwargs = vars(args)
mode = kwargs.pop("mode")
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
dataset = LeRobotDataset(
repo_id=repo_id,
root=root,
local_files_only=local_files_only,
)
if mode == "consolidate":
dataset.consolidate(**kwargs)
elif mode == "push_to_hub":
private = kwargs.pop("private") == 1
dataset.push_to_hub(private=private, **kwargs)
elif mode == "clone":
dataset.clone(**kwargs)
elif mode == "delete":
dataset.delete(**kwargs)
elif mode == "remove_episode":
dataset.remove_episode(**kwargs)
elif mode == "drop_frame":
ep_range = parse_episode_range_strings(kwargs.pop("episode_range"))
dataset.drop_frame(episode_range=ep_range, **kwargs)

View File

@@ -66,7 +66,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif "openx_rlds" in raw_format:
elif raw_format in ["rlds", "openx"]:
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
elif raw_format == "dora_parquet":
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
@@ -204,24 +204,14 @@ def push_dataset_to_hub(
# convert dataset from original raw format to LeRobot format
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
fmt_kwgs = {
"raw_dir": raw_dir,
"videos_dir": videos_dir,
"fps": fps,
"video": video,
"episodes": episodes,
"encoding": encoding,
}
if "openx_rlds." in raw_format:
# Support for official OXE dataset name inside `raw_format`.
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
_, openx_dataset_name = raw_format.split(".")
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir,
videos_dir,
fps,
video,
episodes,
encoding,
)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
@@ -290,7 +280,7 @@ def main():
"--raw-format",
type=str,
required=True,
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).",
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
)
parser.add_argument(
"--repo-id",

View File

@@ -207,11 +207,17 @@ def main():
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--output-dir",
@@ -269,9 +275,10 @@ def main():
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
visualize_dataset(dataset, **vars(args))

View File

@@ -234,6 +234,12 @@ def main():
required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
@@ -282,7 +288,9 @@ def main():
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
local_files_only = kwargs.pop("local_files_only")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
visualize_dataset_html(dataset, **kwargs)

View File

@@ -158,7 +158,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True)
# TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha":
@@ -295,24 +295,12 @@ def test_resume_record(tmpdir, request, robot_type, mock):
dataset = record(**record_kwargs)
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
# init_dataset_return_value = {}
# def wrapped_init_dataset(*args, **kwargs):
# nonlocal init_dataset_return_value
# init_dataset_return_value = init_dataset(*args, **kwargs)
# return init_dataset_return_value
# with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
with pytest.raises(FileExistsError):
# Dataset already exists, but resume=False by default
record(**record_kwargs)
dataset = record(**record_kwargs, resume=True)
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
# assert (
# init_dataset_return_value["num_episodes"] == 2
# ), "`init_dataset` should load the previous episode"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])

View File

@@ -383,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.

View File

@@ -5,7 +5,7 @@ we skip them for now in our CI.
Example to run backward compatiblity tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
@@ -330,7 +330,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
],
)
@pytest.mark.skip(
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
"Not compatible with our CI since it downloads raw datasets. Run with `python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")