Compare commits


18 Commits

Author SHA1 Message Date
Thomas Wolf
b670d3c43e Update lerobot/scripts/push_dataset_to_hub.py
Co-authored-by: Remi <re.cadene@gmail.com>
2024-05-29 15:30:39 +02:00
Thomas Wolf
efd3357124 Update lerobot/scripts/push_dataset_to_hub.py
Co-authored-by: Remi <re.cadene@gmail.com>
2024-05-29 15:30:33 +02:00
Thomas Wolf
785db44bd5 Update lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py
Co-authored-by: Remi <re.cadene@gmail.com>
2024-05-29 15:30:21 +02:00
Thomas Wolf
e2f690e779 proposal for a more general Dora env 2024-05-29 15:29:41 +02:00
Thomas Wolf
68a680a9eb make aloha dora more flexible for A koch arm 2024-05-29 11:40:02 +02:00
Thomas Wolf
ce5329cf44 push_to_hub less hardcoded 2024-05-29 11:39:25 +02:00
Remi Cadene
f409bee6b1 WIP 2024-05-24 09:37:11 +00:00
Remi Cadene
c91ececc75 WIP 2024-05-24 09:06:29 +00:00
Simon Alibert
8a10e8442b Add boilerplate code
adding dora node logic

Add some documentation

refactor
2024-05-23 12:59:39 +00:00
Remi Cadene
10c5151fc6 remove hardcoding 2024-05-23 09:14:17 +00:00
Remi Cadene
a7d24f6dc2 works 2024-05-23 07:32:37 +00:00
Remi Cadene
642f1d0328 Add act_real_world.yaml 2024-05-22 15:15:36 +00:00
Remi Cadene
4843988d81 fix 2024-05-22 15:15:13 +00:00
Remi Cadene
772927616a fix 2024-05-22 09:12:34 +00:00
Remi Cadene
d52c6037e8 fix 2024-05-22 09:03:43 +00:00
Remi Cadene
b0cb342795 WIP Add aloha_dora_format 2024-05-21 21:16:31 +00:00
Remi Cadene
8460ea6f83 rename + format 2024-05-20 16:37:21 +00:00
haixuantao
9a3b0b738a Adding dora-record script 2024-05-20 16:34:21 +00:00
98 changed files with 1207 additions and 8266 deletions


@@ -10,6 +10,7 @@ on:
env:
PYTHON_VERSION: "3.10"
# CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
jobs:
latest-cpu:
@@ -34,8 +35,6 @@ jobs:
- name: Check out code
uses: actions/checkout@v4
with:
lfs: true
- name: Login to DockerHub
uses: docker/login-action@v3
@@ -52,6 +51,30 @@ jobs:
tags: huggingface/lerobot-cpu
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
# - name: Post to a Slack channel
# id: slack
# #uses: slackapi/slack-github-action@v1.25.0
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
# with:
# # Slack channel id, channel name, or user id to post message.
# # See also: https://api.slack.com/methods/chat.postMessage#channels
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
# # For posting a rich message using Block Kit
# payload: |
# {
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
# "blocks": [
# {
# "type": "section",
# "text": {
# "type": "mrkdwn",
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
# }
# }
# ]
# }
# env:
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
latest-cuda:
name: GPU
@@ -74,8 +97,6 @@ jobs:
- name: Check out code
uses: actions/checkout@v4
with:
lfs: true
- name: Login to DockerHub
uses: docker/login-action@v3
@@ -92,40 +113,27 @@ jobs:
tags: huggingface/lerobot-gpu
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
latest-cuda-dev:
name: GPU Dev
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo df -h
# sudo ls -l /usr/local/lib/
# sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Check out code
uses: actions/checkout@v4
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU dev
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/lerobot-gpu-dev/Dockerfile
push: true
tags: huggingface/lerobot-gpu:dev
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
# - name: Post to a Slack channel
# id: slack
# #uses: slackapi/slack-github-action@v1.25.0
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
# with:
# # Slack channel id, channel name, or user id to post message.
# # See also: https://api.slack.com/methods/chat.postMessage#channels
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
# # For posting a rich message using Block Kit
# payload: |
# {
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
# "blocks": [
# {
# "type": "section",
# "text": {
# "type": "mrkdwn",
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
# }
# }
# ]
# }
# env:
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}


@@ -70,8 +70,6 @@ jobs:
# files: ./coverage.xml
# verbose: true
- name: Tests end-to-end
env:
DEVICE: cuda
run: make test-end-to-end
# - name: Generate Report

.gitignore

@@ -2,16 +2,11 @@
logs
tmp
wandb
# Data
data
outputs
# Apple
.DS_Store
# VS Code
.vscode
rl
.DS_Store
# HPC
nautilus/*.yaml
@@ -95,7 +90,6 @@ instance/
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
@@ -108,6 +102,13 @@ ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
@@ -118,14 +119,6 @@ celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
@@ -143,9 +136,3 @@ dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/


@@ -10,7 +10,6 @@ endif
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
DEVICE ?= cpu
build-cpu:
docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
@@ -19,29 +18,25 @@ build-gpu:
docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .
test-end-to-end:
${MAKE} DEVICE=$(DEVICE) test-act-ete-train
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
${MAKE} test-act-ete-train
${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
${MAKE} test-tdmpc-ete-train
${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval
test-act-ete-train:
python lerobot/scripts/train.py \
policy=act \
policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=$(DEVICE) \
training.save_checkpoint=true \
device=cpu \
training.save_model=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
@@ -50,67 +45,35 @@ test-act-ete-train:
test-act-ete-eval:
python lerobot/scripts/eval.py \
-p tests/outputs/act/checkpoints/000002/pretrained_model \
-p tests/outputs/act/checkpoints/000002 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=$(DEVICE) \
test-act-ete-train-amp:
python lerobot/scripts/train.py \
policy=act \
policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=$(DEVICE) \
training.save_checkpoint=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act_amp/ \
use_amp=true
test-act-ete-eval-amp:
python lerobot/scripts/eval.py \
-p tests/outputs/act_amp/checkpoints/000002/pretrained_model \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=$(DEVICE) \
use_amp=true
device=cpu \
test-diffusion-ete-train:
python lerobot/scripts/train.py \
policy=diffusion \
policy.down_dims=\[64,128,256\] \
policy.diffusion_step_embed_dim=32 \
policy.num_inference_steps=10 \
env=pusht \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=$(DEVICE) \
training.save_checkpoint=true \
device=cpu \
training.save_model=true \
training.save_freq=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/
test-diffusion-ete-eval:
python lerobot/scripts/eval.py \
-p tests/outputs/diffusion/checkpoints/000002/pretrained_model \
-p tests/outputs/diffusion/checkpoints/000002 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=$(DEVICE) \
device=cpu \
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
test-tdmpc-ete-train:
python lerobot/scripts/train.py \
policy=tdmpc \
@@ -119,23 +82,24 @@ test-tdmpc-ete-train:
dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
training.online_steps=2 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=2 \
device=$(DEVICE) \
training.save_checkpoint=true \
device=cpu \
training.save_model=true \
training.save_freq=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval:
python lerobot/scripts/eval.py \
-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
-p tests/outputs/tdmpc/checkpoints/000002 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=$(DEVICE) \
device=cpu \
test-default-ete-eval:
python lerobot/scripts/eval.py \
@@ -143,21 +107,4 @@ test-default-ete-eval:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=$(DEVICE) \
test-act-pusht-tutorial:
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
python lerobot/scripts/train.py \
policy=created_by_Makefile.yaml \
env=pusht \
wandb.enable=False \
training.offline_steps=2 \
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=2 \
device=$(DEVICE) \
training.save_model=true \
training.save_freq=2 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act_pusht/
rm lerobot/configs/policy/created_by_Makefile.yaml
device=cpu \


@@ -77,10 +77,6 @@ Install 🤗 LeRobot:
pip install .
```
> **NOTE:** Depending on your platform, if you encounter any build errors during this step
you may need to install `cmake` and `build-essential` to build some of our dependencies.
On Linux: `sudo apt-get install cmake build-essential`
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
- [aloha](https://github.com/huggingface/gym-aloha)
- [xarm](https://github.com/huggingface/gym-xarm)
@@ -103,7 +99,6 @@ wandb login
```
.
├── examples # contains demonstration examples, start here to learn about LeRobot
| └── advanced # contains even more examples for those who have mastered the basics
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
@@ -127,21 +122,13 @@ wandb login
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
You can also locally visualize episodes from a dataset by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
```
or from a dataset in a local folder with the root `DATA_DIR` environment variable
```bash
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
```
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
@@ -149,51 +136,6 @@ https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-f
Our script can also visualize datasets stored on a remote server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.
### The `LeRobotDataset` format
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or from a local folder with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and indexed into like any Hugging Face or PyTorch dataset. For instance, `dataset[0]` will retrieve a single sample containing observations and actions as PyTorch tensors, ready to be fed to a model.
A specificity of `LeRobotDataset` is that we can retrieve several frames for one sample query. By setting `delta_timestamps` to a list of delta timestamps per key, e.g. `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}`, one retrieves 4 images for each query: one from 1 second before the current time step, two from 0.5 and 0.2 seconds before, and the final one at the current time step (0 seconds). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
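For illustration, here is a minimal sketch of that `delta_timestamps` mechanism in Python (the import path and camera key follow example 1 and the structure description below; treat them as assumptions that may differ for your dataset or library version):
```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Ask for 4 frames of the top camera per sample: 1s, 0.5s and 0.2s before the
# current time step, plus the current time step itself.
delta_timestamps = {"observation.images.cam_high": [-1, -0.5, -0.2, 0]}
dataset = LeRobotDataset("lerobot/aloha_static_coffee", delta_timestamps=delta_timestamps)

item = dataset[0]
# The requested key now holds a stack of 4 frames instead of a single frame.
print(item["observation.images.cam_high"].shape)  # e.g. (4, c, h, w)
```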
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data, which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that covers most types of features and specificities present in reinforcement learning and robotics, in simulation and in the real world, with a focus on cameras and robot states.
Here are the important details and internal structure of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset, but not the main aspects:
```
dataset attributes:
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
│ ├ observation.images.cam_high: VideoFrame
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp': float32 timestamp in the video}
│ ├ observation.state: List of float32: positions of the arm joints (for instance)
│ ... (more observations)
│ ├ action: List of float32
│ ├ episode_index: int64: index of the episode for this sample
│ ├ frame_index: int64: index of the frame for this sample in the episode; starts at 0 for each episode
│ ├ timestamp: float32: timestamp in the episode
│ ├ next.done: bool: indicates the end of an episode; True for the last frame in each episode
│ └ index: int64: general index in the whole dataset
├ episode_data_index: contains 2 tensors with the start and end indices of each episode
│ ├ from: 1D int64 tensor of first frame index for each episode: shape (num episodes,) starts with 0
│ └ to: 1D int64 tensor of last frame index for each episode: shape (num episodes,)
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ ...
├ info: a dictionary of metadata on the dataset
│ ├ fps: float - frames per second the dataset is recorded/synchronized at
│ └ video: bool - indicates if frames are encoded in mp4 video files to save space or stored as png files
├ videos_dir: path to where the mp4 videos or png images are stored/accessed
└ camera_keys: List of string: the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
```
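As a hedged usage sketch of the attributes listed above (names are taken from this structure description; verify them against your version of the library):
```python
# Continuing with the dataset loaded above.
ep_from = dataset.episode_data_index["from"][0].item()
ep_to = dataset.episode_data_index["to"][0].item()
print(f"episode 0 spans dataset indices {ep_from} to {ep_to}")

print(dataset.camera_keys)                          # e.g. ["observation.images.cam_high", ...]
print(dataset.info["fps"], dataset.info["video"])   # recording frame rate, mp4 vs png storage
print(dataset.stats["observation.state"]["mean"])   # per-feature statistics
print(dataset.hf_dataset)                           # the underlying Hugging Face dataset
```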
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using the Hugging Face datasets library serialization to parquet
- videos stored in mp4 format to save space, or as png files
- episode_data_index saved using the `safetensors` tensor serialization format (see the sketch below)
- stats saved using the `safetensors` tensor serialization format
- info saved using JSON
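As a rough sketch of what that `safetensors` round trip looks like (the file name and tensor values here are made up for illustration; the library manages the real files):
```python
import torch
from safetensors.torch import load_file, save_file

# Hypothetical episode_data_index for a 2-episode dataset, saved and re-loaded
# with safetensors as described in the list above.
episode_data_index = {"from": torch.tensor([0, 120]), "to": torch.tensor([120, 231])}
save_file(episode_data_index, "episode_data_index.safetensors")
print(load_file("episode_data_index.safetensors"))
```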
Datasets can be uploaded/downloaded from the Hugging Face hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.
### Evaluate a pretrained policy
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
@@ -207,19 +149,18 @@ python lerobot/scripts/eval.py \
```
Note: After training your own policy, you can re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py -p {OUTPUT_DIR}/checkpoints/last/pretrained_model
python lerobot/scripts/eval.py \
-p PATH/TO/TRAIN/OUTPUT/FOLDER
```
See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
```bash
python lerobot/scripts/train.py \
policy=act \
@@ -233,19 +174,6 @@ The experiment directory is automatically generated and will show up in yellow i
hydra.run.dir=your/new/experiment/dir
```
In the experiment directory there will be a folder called `checkpoints` which will have the following structure:
```bash
checkpoints
├── 000250 # checkpoint_dir for training step 250
│ ├── pretrained_model # Hugging Face pretrained model dir
│ │ ├── config.json # Hugging Face pretrained model config
│ │ ├── config.yaml # consolidated Hydra config
│ │ ├── model.safetensors # model weights
│ │ └── README.md # Hugging Face model card
│ └── training_state.pth # optimizer/scheduler/rng state and training step
```
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
```bash
@@ -256,19 +184,7 @@ A link to the wandb logs for the run will also show up in yellow in your termina
![](media/wandb.png)
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
#### Reproduce state-of-the-art (SOTA)
We have organized our configuration files (found under [`lerobot/configs`](./lerobot/configs)) such that they reproduce SOTA results from a given model variant in their respective original works. Simply running:
```bash
python lerobot/scripts/train.py policy=diffusion env=pusht
```
reproduces SOTA results for Diffusion Policy on the PushT task.
Pretrained policies, along with reproduction details, can be found under the "Models" section of https://huggingface.co/lerobot.
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
## Contribute
@@ -299,14 +215,14 @@ If your dataset format is not supported, implement your own in `lerobot/common/d
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
To upload these to the hub, run the following:
```bash
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
```
See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
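For reference, a hedged sketch of how a policy uploaded this way can then be pulled back down in Python (class name and hub id follow the `lerobot/diffusion_pusht` example above; the exact import path may differ across versions):
```python
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Downloads config.json and model.safetensors from the hub and instantiates the policy.
policy = DiffusionPolicy.from_pretrained("lerobot/diffusion_pusht")
policy.eval()
```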


@@ -1,40 +0,0 @@
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
# Configure image
ARG PYTHON_VERSION=3.10
ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
git git-lfs openssh-client \
nano vim less util-linux \
htop atop nvtop \
sed gawk grep curl wget \
tcpdump sysstat screen tmux \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Install gh cli tool
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
&& mkdir -p -m 755 /etc/apt/keyrings \
&& wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
&& apt update \
&& apt install gh -y \
&& apt clean && rm -rf /var/lib/apt/lists/*
# Setup `python`
RUN ln -s /usr/bin/python3 /usr/bin/python
# Install poetry
RUN curl -sSL https://install.python-poetry.org | python -
ENV PATH="/root/.local/bin:$PATH"
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
RUN poetry config virtualenvs.create false
RUN poetry config virtualenvs.in-project true
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"


@@ -4,15 +4,18 @@ FROM nvidia/cuda:12.4.1-base-ubuntu22.04
ARG PYTHON_VERSION=3.10
ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
git git-lfs openssh-client \
nano vim ffmpeg \
htop atop nvtop \
sed gawk grep curl wget \
tcpdump sysstat screen \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Create virtual environment
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
RUN python -m venv /opt/venv
@@ -20,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot
COPY . /lerobot
RUN git lfs install
RUN git clone https://github.com/huggingface/lerobot.git
WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"


@@ -1,183 +0,0 @@
This tutorial will explain the training script, how to use it, and particularly the use of Hydra to configure everything needed for the training run.
## The training script
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
- Loads a Hydra configuration file for the following steps (more on Hydra in a moment).
- Makes a simulation environment.
- Makes a dataset corresponding to that simulation environment.
- Makes a policy.
- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing.
## Basics of how we use Hydra
Explaining the ins and outs of [Hydra](https://hydra.cc/docs/intro/) is beyond the scope of this document, but here we'll share the main points you need to know.
First, `lerobot/configs` has a directory structure like this:
```
.
├── default.yaml
├── env
│ ├── aloha.yaml
│ ├── pusht.yaml
│ └── xarm.yaml
└── policy
├── act.yaml
├── diffusion.yaml
└── tdmpc.yaml
```
**_For brevity, in the rest of this document we'll drop the leading `lerobot/configs` path. So `default.yaml` really refers to `lerobot/configs/default.yaml`._**
When you run the training script with
```python
python lerobot/scripts/train.py
```
Hydra is set up to read `default.yaml` (via the `@hydra.main` decorator). If you take a look at `@hydra.main`'s arguments you will see `config_path="../configs", config_name="default"`. At the top of `default.yaml` is a `defaults` section which looks like this:
```yaml
defaults:
- _self_
- env: pusht
- policy: diffusion
```
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order, as any configuration parameters with the same name will be overridden. Thus, `default.yaml` is overridden by `env/pusht.yaml`, which is overridden by `policy/diffusion.yaml`_.
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
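To make the decorator mentioned above concrete, here is a minimal sketch of a Hydra entry point using the same arguments (the `version_base` value and the body are illustrative assumptions, not the actual `train.py`):
```python
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="../configs", config_name="default", version_base="1.2")
def train(cfg: DictConfig):
    # cfg is default.yaml merged with env/pusht.yaml and policy/diffusion.yaml,
    # plus any param_name=param_value overrides passed on the command line.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    train()
```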
Thanks to this `defaults` section in `default.yaml`, if you want to train Diffusion Policy with PushT, you really only need to run:
```bash
python lerobot/scripts/train.py
```
However, you can be more explicit and launch the exact same Diffusion Policy training on PushT with:
```bash
python lerobot/scripts/train.py policy=diffusion env=pusht
```
This way of overriding defaults via the CLI is especially useful when you want to change the policy and/or environment. For instance, you can train ACT on the default Aloha environment with:
```bash
python lerobot/scripts/train.py policy=act env=aloha
```
There are two things to note here:
- Config overrides are passed as `param_name=param_value`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
## Overriding configuration parameters in the CLI
Now let's say that we want to train on a different task in the Aloha environment. If you look in `env/aloha.yaml` you will see something like:
```yaml
# lerobot/configs/env/aloha.yaml
env:
task: AlohaInsertion-v0
```
And if you look in `policy/act.yaml` you will see something like:
```yaml
# lerobot/configs/policy/act.yaml
dataset_repo_id: lerobot/aloha_sim_insertion_human
```
But our Aloha environment actually supports a cube transfer task as well. To train for this task, you could manually modify the two yaml configuration files respectively.
First, we'd need to switch to using the cube transfer task for the ALOHA environment.
```diff
# lerobot/configs/env/aloha.yaml
env:
- task: AlohaInsertion-v0
+ task: AlohaTransferCube-v0
```
Then, we'd also need to switch to using the cube transfer dataset.
```diff
# lerobot/configs/policy/act.yaml
-dataset_repo_id: lerobot/aloha_sim_insertion_human
+dataset_repo_id: lerobot/aloha_sim_transfer_cube_human
```
Then, you'd be able to run:
```bash
python lerobot/scripts/train.py policy=act env=aloha
```
and you'd be training and evaluating on the cube transfer task.
An alternative approach to editing the yaml configuration files would be to override the defaults via the command line:
```bash
python lerobot/scripts/train.py \
policy=act \
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
env=aloha \
env.task=AlohaTransferCube-v0
```
There's something new here. Notice the `.` delimiter used to traverse the configuration hierarchy. _But be aware that the `defaults` section is an exception. As you saw above, we didn't need to write `defaults.policy=act` in the CLI. `policy=act` was enough._
Putting all that knowledge together, here's the command that was used to train https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human.
```bash
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human \
device=cuda \
env=aloha \
env.task=AlohaTransferCube-v0 \
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
policy=act \
training.eval_freq=10000 \
training.log_freq=250 \
training.offline_steps=100000 \
training.save_model=true \
training.save_freq=25000 \
eval.n_episodes=50 \
eval.batch_size=50 \
wandb.enable=false
```
There's one new thing here: `hydra.run.dir=outputs/train/act_aloha_sim_transfer_cube_human`, which specifies where to save the training output.
## Using a configuration file not in `lerobot/configs`
Above we discussed how our training script is set up such that Hydra looks for `default.yaml` in `lerobot/configs`. But if you have a configuration file elsewhere in your filesystem, you may use:
```bash
python lerobot/scripts/train.py --config-dir PARENT/PATH --config-name FILE_NAME_WITHOUT_EXTENSION
```
Note: here we use regular syntax for providing CLI arguments to a Python script, not Hydra's `param_name=param_value` syntax.
As a concrete example, this becomes particularly handy when you have a folder with training outputs, and would like to re-run the training. For example, say you previously ran the training script with one of the earlier commands and have `outputs/train/my_experiment/checkpoints/pretrained_model/config.yaml`. This `config.yaml` file will have the full set of configuration parameters within it. To run the training with the same configuration again, do:
```bash
python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpoints/last/pretrained_model --config-name config
```
Note that you may still use the regular syntax for config parameter overrides (eg: by adding `training.offline_steps=200000`).
---
So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
Please, head on over to our [advanced tutorial on adapting policy configuration to various environments](./advanced/train_act_pusht/train_act_pusht.md) to learn more.
Or in the meantime, happy coding! 🤗


@@ -1,37 +0,0 @@
This tutorial explains how to resume a training run that you've started with the training script. If you don't know how our training script and configuration system works, please read [4_train_policy_with_script.md](./4_train_policy_with_script.md) first.
## Basic training resumption
Let's consider the example of training ACT for one of the ALOHA tasks. Here's a command that can achieve that:
```bash
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/run_resumption \
policy=act \
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human \
env=aloha \
env.task=AlohaTransferCube-v0 \
training.log_freq=25 \
training.save_checkpoint=true \
training.save_freq=100
```
Here we're using the default dataset and environment for ACT, and we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can test resumption. You should be able to see some logging and have a first checkpoint within 1 minute. Please interrupt the training after the first checkpoint.
To resume, all that we have to do is run the training script, providing the run directory, and the resume option:
```bash
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/run_resumption \
resume=true
```
You should see from the logging that your training picks up from where it left off.
Note that with `resume=true`, the configuration file from the last checkpoint in the training output directory is loaded. So it doesn't matter that we haven't provided all the other configuration parameters from our previous command (although there may be warnings to notify you that your command has a different configuration than the checkpoint).
---
Now you should know how to resume your training run in case it gets interrupted or you want to extend a finished training run.
Happy coding! 🤗


@@ -1,87 +0,0 @@
# @package _global_
# Change the seed to match what PushT eval uses
# (to avoid evaluating on seeds used for generating the training data).
seed: 100000
# Change the dataset repository to the PushT one.
dataset_repo_id: lerobot/pusht
override_dataset_stats:
observation.image:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: 10000
save_freq: 100000
log_freq: 250
save_model: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.image: mean_std
# Use min_max normalization just because it's more standard.
observation.state: min_max
output_normalization_modes:
# Use min_max normalization just because it's more standard.
action: min_max
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0


@@ -1,70 +0,0 @@
In this tutorial we will learn how to adapt a policy configuration to be compatible with a new environment and dataset. As a concrete example, we will adapt the default configuration for ACT to be compatible with the PushT environment and dataset.
If you haven't already read our tutorial on the [training script and configuration tooling](../4_train_policy_with_script.md) please do so prior to tackling this tutorial.
Let's get started!
Suppose we want to train ACT for PushT. Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
```bash
python lerobot/scripts/train.py policy=act env=pusht dataset_repo_id=lerobot/pusht
```
We need to adapt the parameters of the ACT policy configuration to the PushT environment. The most important ones are the image keys.
ALOHA's datasets and environments typically use a variable number of cameras. In `lerobot/configs/policy/act.yaml` you may notice two relevant sections. Here we show you the minimal diff needed to adjust to PushT:
```diff
override_dataset_stats:
- observation.images.top:
+ observation.image:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
policy:
input_shapes:
- observation.images.top: [3, 480, 640]
+ observation.image: [3, 96, 96]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
input_normalization_modes:
- observation.images.top: mean_std
+ observation.image: mean_std
observation.state: min_max
output_normalization_modes:
action: min_max
```
Here we've accounted for the following:
- PushT uses "observation.image" for its image key.
- PushT provides smaller images.
_Side note: technically we could override these via the CLI, but with many changes it gets a bit messy, and we also have a bit of a challenge in that we're using `.` in our observation keys which is treated by Hydra as a hierarchical separator_.
For your convenience, we provide [`act_pusht.yaml`](./act_pusht.yaml) in this directory. It contains the diff above, plus some other (optional) ones that are explained within. Please copy it into `lerobot/configs/policy` with:
```bash
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/act_pusht.yaml
```
(remember from a [previous tutorial](../4_train_policy_with_script.md) that Hydra will look in the `lerobot/configs` directory). Now try running the following.
<!-- Note to contributor: are you changing this command? Note that it's tested in `Makefile`, so change it there too! -->
```bash
python lerobot/scripts/train.py policy=act_pusht env=pusht
```
Notice that this is much the same as the command that failed at the start of the tutorial, only:
- Now we are using `policy=act_pusht` to point to our new configuration file.
- We can drop `dataset_repo_id=lerobot/pusht` as the change is incorporated in our new configuration file.
Hurrah! You're now training ACT for the PushT environment.
---
The bottom line of this tutorial is that when training policies for different environments and datasets you will need to understand what parts of the policy configuration are specific to those and make changes accordingly.
Happy coding! 🤗


@@ -1,89 +0,0 @@
# Using `lerobot` on a real world arm
In this example, we'll be using `lerobot` on a real world arm to:
- record a dataset in the `lerobot` format
- (soon) train a policy on it
- (soon) run the policy in the real-world
## Which robotic arm to use
In this example we're using the [open-source low-cost arm from Alexander Koch](https://github.com/AlexanderKoch-Koch/low_cost_robot) in the specific setup of:
- having 6 servos per arm, i.e. using the elbow-to-wrist extension
- adding two cameras around it, one on top and one in the front
- having a teleoperation arm as well (build the leader and the follower arms in A. Koch repo, both with elbow-to-wrist extensions)
I'm using these cameras (but the setup should not be sensitive to the exact cameras you're using):
- C922 Pro Stream Webcam
- Intel(R) RealSense D455 (using only the RGB input)
In general, this example should be very easily extendable to any type of arm using Dynamixel servos with at least one camera by changing a couple of configuration values in the gym env.
## Install the example
Follow these steps:
- install `lerobot`
- install the Dynamixel-sdk: `pip install dynamixel-sdk`
## Usage
### 0 - record examples
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
```
DATA_DIR='./data' python record_training_data.py \
--repo-id=thomwolf/blue_red_sort \
--num-episodes=50 \
--num-frames=400
```
TODO:
- various length episodes
- being able to drop episodes
- checking uploading to the hub
### 1 - visualize the dataset
Use the standard dataset visualization script pointing it to the right folder:
```
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
--repo-id thomwolf/blue_red_sort \
--episode-index 0
```
### 2 - Train a policy
From the example directory, let's run this command to train a model using ACT:
```
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
device=cuda \
hydra.searchpath=[file://./train_config/] \
hydra.run.dir=./outputs/train/blue_red_sort \
dataset_repo_id=thomwolf/blue_red_sort \
env=gym_real_world \
policy=act_real_world \
wandb.enable=false
```
### 3 - Evaluate the policy in the real world
From the example directory, let's run this command to evaluate our policy.
The configuration for running the policy is stored in the checkpoint of the model.
You can override parameters as follows:
```
python run_policy.py \
-p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/ \
env.episode_length=1000
```
## Convert an hdf5 dataset recorded with the original ACT repo
You can convert a dataset recorded as raw HDF5 files (as in https://github.com/tonyzhaozh/act) with the following command:
```
python ./lerobot/scripts/push_dataset_to_hub.py
```


@@ -1,840 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from safetensors.torch import load_file, save_file\n",
"from pprint import pprint"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/policy_last.ckpt\"\n",
"converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/model.safetensors\"\n",
"\n",
"comparison_main_path = \"/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/\"\n",
"comparison_safetensor_path = comparison_main_path + \"model.safetensors\"\n",
"comparison_config_json_path = comparison_main_path + \"config.json\"\n",
"comparison_config_yaml_path = comparison_main_path + \"config.yaml\""
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"a = torch.load(original_ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"b = load_file(comparison_safetensor_path)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['model.action_head.bias',\n",
" 'model.action_head.weight',\n",
" 'model.backbone.bn1.bias',\n",
" 'model.backbone.bn1.running_mean',\n",
" 'model.backbone.bn1.running_var',\n",
" 'model.backbone.bn1.weight',\n",
" 'model.backbone.conv1.weight',\n",
" 'model.backbone.layer1.0.bn1.bias',\n",
" 'model.backbone.layer1.0.bn1.running_mean',\n",
" 'model.backbone.layer1.0.bn1.running_var',\n",
" 'model.backbone.layer1.0.bn1.weight',\n",
" 'model.backbone.layer1.0.bn2.bias',\n",
" 'model.backbone.layer1.0.bn2.running_mean',\n",
" 'model.backbone.layer1.0.bn2.running_var',\n",
" 'model.backbone.layer1.0.bn2.weight',\n",
" 'model.backbone.layer1.0.conv1.weight',\n",
" 'model.backbone.layer1.0.conv2.weight',\n",
" 'model.backbone.layer1.1.bn1.bias',\n",
" 'model.backbone.layer1.1.bn1.running_mean',\n",
" 'model.backbone.layer1.1.bn1.running_var',\n",
" 'model.backbone.layer1.1.bn1.weight',\n",
" 'model.backbone.layer1.1.bn2.bias',\n",
" 'model.backbone.layer1.1.bn2.running_mean',\n",
" 'model.backbone.layer1.1.bn2.running_var',\n",
" 'model.backbone.layer1.1.bn2.weight',\n",
" 'model.backbone.layer1.1.conv1.weight',\n",
" 'model.backbone.layer1.1.conv2.weight',\n",
" 'model.backbone.layer2.0.bn1.bias',\n",
" 'model.backbone.layer2.0.bn1.running_mean',\n",
" 'model.backbone.layer2.0.bn1.running_var',\n",
" 'model.backbone.layer2.0.bn1.weight',\n",
" 'model.backbone.layer2.0.bn2.bias',\n",
" 'model.backbone.layer2.0.bn2.running_mean',\n",
" 'model.backbone.layer2.0.bn2.running_var',\n",
" 'model.backbone.layer2.0.bn2.weight',\n",
" 'model.backbone.layer2.0.conv1.weight',\n",
" 'model.backbone.layer2.0.conv2.weight',\n",
" 'model.backbone.layer2.0.downsample.0.weight',\n",
" 'model.backbone.layer2.0.downsample.1.bias',\n",
" 'model.backbone.layer2.0.downsample.1.running_mean',\n",
" 'model.backbone.layer2.0.downsample.1.running_var',\n",
" 'model.backbone.layer2.0.downsample.1.weight',\n",
" 'model.backbone.layer2.1.bn1.bias',\n",
" 'model.backbone.layer2.1.bn1.running_mean',\n",
" 'model.backbone.layer2.1.bn1.running_var',\n",
" 'model.backbone.layer2.1.bn1.weight',\n",
" 'model.backbone.layer2.1.bn2.bias',\n",
" 'model.backbone.layer2.1.bn2.running_mean',\n",
" 'model.backbone.layer2.1.bn2.running_var',\n",
" 'model.backbone.layer2.1.bn2.weight',\n",
" 'model.backbone.layer2.1.conv1.weight',\n",
" 'model.backbone.layer2.1.conv2.weight',\n",
" 'model.backbone.layer3.0.bn1.bias',\n",
" 'model.backbone.layer3.0.bn1.running_mean',\n",
" 'model.backbone.layer3.0.bn1.running_var',\n",
" 'model.backbone.layer3.0.bn1.weight',\n",
" 'model.backbone.layer3.0.bn2.bias',\n",
" 'model.backbone.layer3.0.bn2.running_mean',\n",
" 'model.backbone.layer3.0.bn2.running_var',\n",
" 'model.backbone.layer3.0.bn2.weight',\n",
" 'model.backbone.layer3.0.conv1.weight',\n",
" 'model.backbone.layer3.0.conv2.weight',\n",
" 'model.backbone.layer3.0.downsample.0.weight',\n",
" 'model.backbone.layer3.0.downsample.1.bias',\n",
" 'model.backbone.layer3.0.downsample.1.running_mean',\n",
" 'model.backbone.layer3.0.downsample.1.running_var',\n",
" 'model.backbone.layer3.0.downsample.1.weight',\n",
" 'model.backbone.layer3.1.bn1.bias',\n",
" 'model.backbone.layer3.1.bn1.running_mean',\n",
" 'model.backbone.layer3.1.bn1.running_var',\n",
" 'model.backbone.layer3.1.bn1.weight',\n",
" 'model.backbone.layer3.1.bn2.bias',\n",
" 'model.backbone.layer3.1.bn2.running_mean',\n",
" 'model.backbone.layer3.1.bn2.running_var',\n",
" 'model.backbone.layer3.1.bn2.weight',\n",
" 'model.backbone.layer3.1.conv1.weight',\n",
" 'model.backbone.layer3.1.conv2.weight',\n",
" 'model.backbone.layer4.0.bn1.bias',\n",
" 'model.backbone.layer4.0.bn1.running_mean',\n",
" 'model.backbone.layer4.0.bn1.running_var',\n",
" 'model.backbone.layer4.0.bn1.weight',\n",
" 'model.backbone.layer4.0.bn2.bias',\n",
" 'model.backbone.layer4.0.bn2.running_mean',\n",
" 'model.backbone.layer4.0.bn2.running_var',\n",
" 'model.backbone.layer4.0.bn2.weight',\n",
" 'model.backbone.layer4.0.conv1.weight',\n",
" 'model.backbone.layer4.0.conv2.weight',\n",
" 'model.backbone.layer4.0.downsample.0.weight',\n",
" 'model.backbone.layer4.0.downsample.1.bias',\n",
" 'model.backbone.layer4.0.downsample.1.running_mean',\n",
" 'model.backbone.layer4.0.downsample.1.running_var',\n",
" 'model.backbone.layer4.0.downsample.1.weight',\n",
" 'model.backbone.layer4.1.bn1.bias',\n",
" 'model.backbone.layer4.1.bn1.running_mean',\n",
" 'model.backbone.layer4.1.bn1.running_var',\n",
" 'model.backbone.layer4.1.bn1.weight',\n",
" 'model.backbone.layer4.1.bn2.bias',\n",
" 'model.backbone.layer4.1.bn2.running_mean',\n",
" 'model.backbone.layer4.1.bn2.running_var',\n",
" 'model.backbone.layer4.1.bn2.weight',\n",
" 'model.backbone.layer4.1.conv1.weight',\n",
" 'model.backbone.layer4.1.conv2.weight',\n",
" 'model.decoder.layers.0.linear1.bias',\n",
" 'model.decoder.layers.0.linear1.weight',\n",
" 'model.decoder.layers.0.linear2.bias',\n",
" 'model.decoder.layers.0.linear2.weight',\n",
" 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n",
" 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n",
" 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n",
" 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n",
" 'model.decoder.layers.0.norm1.bias',\n",
" 'model.decoder.layers.0.norm1.weight',\n",
" 'model.decoder.layers.0.norm2.bias',\n",
" 'model.decoder.layers.0.norm2.weight',\n",
" 'model.decoder.layers.0.norm3.bias',\n",
" 'model.decoder.layers.0.norm3.weight',\n",
" 'model.decoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.decoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.decoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.decoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.decoder_pos_embed.weight',\n",
" 'model.encoder.layers.0.linear1.bias',\n",
" 'model.encoder.layers.0.linear1.weight',\n",
" 'model.encoder.layers.0.linear2.bias',\n",
" 'model.encoder.layers.0.linear2.weight',\n",
" 'model.encoder.layers.0.norm1.bias',\n",
" 'model.encoder.layers.0.norm1.weight',\n",
" 'model.encoder.layers.0.norm2.bias',\n",
" 'model.encoder.layers.0.norm2.weight',\n",
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.1.linear1.bias',\n",
" 'model.encoder.layers.1.linear1.weight',\n",
" 'model.encoder.layers.1.linear2.bias',\n",
" 'model.encoder.layers.1.linear2.weight',\n",
" 'model.encoder.layers.1.norm1.bias',\n",
" 'model.encoder.layers.1.norm1.weight',\n",
" 'model.encoder.layers.1.norm2.bias',\n",
" 'model.encoder.layers.1.norm2.weight',\n",
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.2.linear1.bias',\n",
" 'model.encoder.layers.2.linear1.weight',\n",
" 'model.encoder.layers.2.linear2.bias',\n",
" 'model.encoder.layers.2.linear2.weight',\n",
" 'model.encoder.layers.2.norm1.bias',\n",
" 'model.encoder.layers.2.norm1.weight',\n",
" 'model.encoder.layers.2.norm2.bias',\n",
" 'model.encoder.layers.2.norm2.weight',\n",
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.3.linear1.bias',\n",
" 'model.encoder.layers.3.linear1.weight',\n",
" 'model.encoder.layers.3.linear2.bias',\n",
" 'model.encoder.layers.3.linear2.weight',\n",
" 'model.encoder.layers.3.norm1.bias',\n",
" 'model.encoder.layers.3.norm1.weight',\n",
" 'model.encoder.layers.3.norm2.bias',\n",
" 'model.encoder.layers.3.norm2.weight',\n",
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.encoder_img_feat_input_proj.bias',\n",
" 'model.encoder_img_feat_input_proj.weight',\n",
" 'model.encoder_latent_input_proj.bias',\n",
" 'model.encoder_latent_input_proj.weight',\n",
" 'model.encoder_robot_and_latent_pos_embed.weight',\n",
" 'model.encoder_robot_state_input_proj.bias',\n",
" 'model.encoder_robot_state_input_proj.weight',\n",
" 'model.vae_encoder.layers.0.linear1.bias',\n",
" 'model.vae_encoder.layers.0.linear1.weight',\n",
" 'model.vae_encoder.layers.0.linear2.bias',\n",
" 'model.vae_encoder.layers.0.linear2.weight',\n",
" 'model.vae_encoder.layers.0.norm1.bias',\n",
" 'model.vae_encoder.layers.0.norm1.weight',\n",
" 'model.vae_encoder.layers.0.norm2.bias',\n",
" 'model.vae_encoder.layers.0.norm2.weight',\n",
" 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.1.linear1.bias',\n",
" 'model.vae_encoder.layers.1.linear1.weight',\n",
" 'model.vae_encoder.layers.1.linear2.bias',\n",
" 'model.vae_encoder.layers.1.linear2.weight',\n",
" 'model.vae_encoder.layers.1.norm1.bias',\n",
" 'model.vae_encoder.layers.1.norm1.weight',\n",
" 'model.vae_encoder.layers.1.norm2.bias',\n",
" 'model.vae_encoder.layers.1.norm2.weight',\n",
" 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.2.linear1.bias',\n",
" 'model.vae_encoder.layers.2.linear1.weight',\n",
" 'model.vae_encoder.layers.2.linear2.bias',\n",
" 'model.vae_encoder.layers.2.linear2.weight',\n",
" 'model.vae_encoder.layers.2.norm1.bias',\n",
" 'model.vae_encoder.layers.2.norm1.weight',\n",
" 'model.vae_encoder.layers.2.norm2.bias',\n",
" 'model.vae_encoder.layers.2.norm2.weight',\n",
" 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.3.linear1.bias',\n",
" 'model.vae_encoder.layers.3.linear1.weight',\n",
" 'model.vae_encoder.layers.3.linear2.bias',\n",
" 'model.vae_encoder.layers.3.linear2.weight',\n",
" 'model.vae_encoder.layers.3.norm1.bias',\n",
" 'model.vae_encoder.layers.3.norm1.weight',\n",
" 'model.vae_encoder.layers.3.norm2.bias',\n",
" 'model.vae_encoder.layers.3.norm2.weight',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.vae_encoder_action_input_proj.bias',\n",
" 'model.vae_encoder_action_input_proj.weight',\n",
" 'model.vae_encoder_cls_embed.weight',\n",
" 'model.vae_encoder_latent_output_proj.bias',\n",
" 'model.vae_encoder_latent_output_proj.weight',\n",
" 'model.vae_encoder_pos_enc',\n",
" 'model.vae_encoder_robot_state_input_proj.bias',\n",
" 'model.vae_encoder_robot_state_input_proj.weight',\n",
" 'normalize_inputs.buffer_observation_images_front.mean',\n",
" 'normalize_inputs.buffer_observation_images_front.std',\n",
" 'normalize_inputs.buffer_observation_images_top.mean',\n",
" 'normalize_inputs.buffer_observation_images_top.std',\n",
" 'normalize_inputs.buffer_observation_state.mean',\n",
" 'normalize_inputs.buffer_observation_state.std',\n",
" 'normalize_targets.buffer_action.mean',\n",
" 'normalize_targets.buffer_action.std',\n",
" 'unnormalize_outputs.buffer_action.mean',\n",
" 'unnormalize_outputs.buffer_action.std']\n"
]
}
],
"source": [
"dest = list(b.keys())\n",
"pprint(dest)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['model.pos_table',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.0.linear1.weight',\n",
" 'model.transformer.encoder.layers.0.linear1.bias',\n",
" 'model.transformer.encoder.layers.0.linear2.weight',\n",
" 'model.transformer.encoder.layers.0.linear2.bias',\n",
" 'model.transformer.encoder.layers.0.norm1.weight',\n",
" 'model.transformer.encoder.layers.0.norm1.bias',\n",
" 'model.transformer.encoder.layers.0.norm2.weight',\n",
" 'model.transformer.encoder.layers.0.norm2.bias',\n",
" 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.1.linear1.weight',\n",
" 'model.transformer.encoder.layers.1.linear1.bias',\n",
" 'model.transformer.encoder.layers.1.linear2.weight',\n",
" 'model.transformer.encoder.layers.1.linear2.bias',\n",
" 'model.transformer.encoder.layers.1.norm1.weight',\n",
" 'model.transformer.encoder.layers.1.norm1.bias',\n",
" 'model.transformer.encoder.layers.1.norm2.weight',\n",
" 'model.transformer.encoder.layers.1.norm2.bias',\n",
" 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.2.linear1.weight',\n",
" 'model.transformer.encoder.layers.2.linear1.bias',\n",
" 'model.transformer.encoder.layers.2.linear2.weight',\n",
" 'model.transformer.encoder.layers.2.linear2.bias',\n",
" 'model.transformer.encoder.layers.2.norm1.weight',\n",
" 'model.transformer.encoder.layers.2.norm1.bias',\n",
" 'model.transformer.encoder.layers.2.norm2.weight',\n",
" 'model.transformer.encoder.layers.2.norm2.bias',\n",
" 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.3.linear1.weight',\n",
" 'model.transformer.encoder.layers.3.linear1.bias',\n",
" 'model.transformer.encoder.layers.3.linear2.weight',\n",
" 'model.transformer.encoder.layers.3.linear2.bias',\n",
" 'model.transformer.encoder.layers.3.norm1.weight',\n",
" 'model.transformer.encoder.layers.3.norm1.bias',\n",
" 'model.transformer.encoder.layers.3.norm2.weight',\n",
" 'model.transformer.encoder.layers.3.norm2.bias',\n",
" 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.0.linear1.weight',\n",
" 'model.transformer.decoder.layers.0.linear1.bias',\n",
" 'model.transformer.decoder.layers.0.linear2.weight',\n",
" 'model.transformer.decoder.layers.0.linear2.bias',\n",
" 'model.transformer.decoder.layers.0.norm1.weight',\n",
" 'model.transformer.decoder.layers.0.norm1.bias',\n",
" 'model.transformer.decoder.layers.0.norm2.weight',\n",
" 'model.transformer.decoder.layers.0.norm2.bias',\n",
" 'model.transformer.decoder.layers.0.norm3.weight',\n",
" 'model.transformer.decoder.layers.0.norm3.bias',\n",
" 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.1.linear1.weight',\n",
" 'model.transformer.decoder.layers.1.linear1.bias',\n",
" 'model.transformer.decoder.layers.1.linear2.weight',\n",
" 'model.transformer.decoder.layers.1.linear2.bias',\n",
" 'model.transformer.decoder.layers.1.norm1.weight',\n",
" 'model.transformer.decoder.layers.1.norm1.bias',\n",
" 'model.transformer.decoder.layers.1.norm2.weight',\n",
" 'model.transformer.decoder.layers.1.norm2.bias',\n",
" 'model.transformer.decoder.layers.1.norm3.weight',\n",
" 'model.transformer.decoder.layers.1.norm3.bias',\n",
" 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.2.linear1.weight',\n",
" 'model.transformer.decoder.layers.2.linear1.bias',\n",
" 'model.transformer.decoder.layers.2.linear2.weight',\n",
" 'model.transformer.decoder.layers.2.linear2.bias',\n",
" 'model.transformer.decoder.layers.2.norm1.weight',\n",
" 'model.transformer.decoder.layers.2.norm1.bias',\n",
" 'model.transformer.decoder.layers.2.norm2.weight',\n",
" 'model.transformer.decoder.layers.2.norm2.bias',\n",
" 'model.transformer.decoder.layers.2.norm3.weight',\n",
" 'model.transformer.decoder.layers.2.norm3.bias',\n",
" 'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.3.linear1.weight',\n",
" 'model.transformer.decoder.layers.3.linear1.bias',\n",
" 'model.transformer.decoder.layers.3.linear2.weight',\n",
" 'model.transformer.decoder.layers.3.linear2.bias',\n",
" 'model.transformer.decoder.layers.3.norm1.weight',\n",
" 'model.transformer.decoder.layers.3.norm1.bias',\n",
" 'model.transformer.decoder.layers.3.norm2.weight',\n",
" 'model.transformer.decoder.layers.3.norm2.bias',\n",
" 'model.transformer.decoder.layers.3.norm3.weight',\n",
" 'model.transformer.decoder.layers.3.norm3.bias',\n",
" 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.4.linear1.weight',\n",
" 'model.transformer.decoder.layers.4.linear1.bias',\n",
" 'model.transformer.decoder.layers.4.linear2.weight',\n",
" 'model.transformer.decoder.layers.4.linear2.bias',\n",
" 'model.transformer.decoder.layers.4.norm1.weight',\n",
" 'model.transformer.decoder.layers.4.norm1.bias',\n",
" 'model.transformer.decoder.layers.4.norm2.weight',\n",
" 'model.transformer.decoder.layers.4.norm2.bias',\n",
" 'model.transformer.decoder.layers.4.norm3.weight',\n",
" 'model.transformer.decoder.layers.4.norm3.bias',\n",
" 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.5.linear1.weight',\n",
" 'model.transformer.decoder.layers.5.linear1.bias',\n",
" 'model.transformer.decoder.layers.5.linear2.weight',\n",
" 'model.transformer.decoder.layers.5.linear2.bias',\n",
" 'model.transformer.decoder.layers.5.norm1.weight',\n",
" 'model.transformer.decoder.layers.5.norm1.bias',\n",
" 'model.transformer.decoder.layers.5.norm2.weight',\n",
" 'model.transformer.decoder.layers.5.norm2.bias',\n",
" 'model.transformer.decoder.layers.5.norm3.weight',\n",
" 'model.transformer.decoder.layers.5.norm3.bias',\n",
" 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.6.linear1.weight',\n",
" 'model.transformer.decoder.layers.6.linear1.bias',\n",
" 'model.transformer.decoder.layers.6.linear2.weight',\n",
" 'model.transformer.decoder.layers.6.linear2.bias',\n",
" 'model.transformer.decoder.layers.6.norm1.weight',\n",
" 'model.transformer.decoder.layers.6.norm1.bias',\n",
" 'model.transformer.decoder.layers.6.norm2.weight',\n",
" 'model.transformer.decoder.layers.6.norm2.bias',\n",
" 'model.transformer.decoder.layers.6.norm3.weight',\n",
" 'model.transformer.decoder.layers.6.norm3.bias',\n",
" 'model.transformer.decoder.norm.weight',\n",
" 'model.transformer.decoder.norm.bias',\n",
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.0.linear1.weight',\n",
" 'model.encoder.layers.0.linear1.bias',\n",
" 'model.encoder.layers.0.linear2.weight',\n",
" 'model.encoder.layers.0.linear2.bias',\n",
" 'model.encoder.layers.0.norm1.weight',\n",
" 'model.encoder.layers.0.norm1.bias',\n",
" 'model.encoder.layers.0.norm2.weight',\n",
" 'model.encoder.layers.0.norm2.bias',\n",
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.1.linear1.weight',\n",
" 'model.encoder.layers.1.linear1.bias',\n",
" 'model.encoder.layers.1.linear2.weight',\n",
" 'model.encoder.layers.1.linear2.bias',\n",
" 'model.encoder.layers.1.norm1.weight',\n",
" 'model.encoder.layers.1.norm1.bias',\n",
" 'model.encoder.layers.1.norm2.weight',\n",
" 'model.encoder.layers.1.norm2.bias',\n",
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.2.linear1.weight',\n",
" 'model.encoder.layers.2.linear1.bias',\n",
" 'model.encoder.layers.2.linear2.weight',\n",
" 'model.encoder.layers.2.linear2.bias',\n",
" 'model.encoder.layers.2.norm1.weight',\n",
" 'model.encoder.layers.2.norm1.bias',\n",
" 'model.encoder.layers.2.norm2.weight',\n",
" 'model.encoder.layers.2.norm2.bias',\n",
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.3.linear1.weight',\n",
" 'model.encoder.layers.3.linear1.bias',\n",
" 'model.encoder.layers.3.linear2.weight',\n",
" 'model.encoder.layers.3.linear2.bias',\n",
" 'model.encoder.layers.3.norm1.weight',\n",
" 'model.encoder.layers.3.norm1.bias',\n",
" 'model.encoder.layers.3.norm2.weight',\n",
" 'model.encoder.layers.3.norm2.bias',\n",
" 'model.action_head.weight',\n",
" 'model.action_head.bias',\n",
" 'model.is_pad_head.weight',\n",
" 'model.is_pad_head.bias',\n",
" 'model.query_embed.weight',\n",
" 'model.input_proj.weight',\n",
" 'model.input_proj.bias',\n",
" 'model.backbones.0.0.body.conv1.weight',\n",
" 'model.backbones.0.0.body.bn1.weight',\n",
" 'model.backbones.0.0.body.bn1.bias',\n",
" 'model.backbones.0.0.body.bn1.running_mean',\n",
" 'model.backbones.0.0.body.bn1.running_var',\n",
" 'model.backbones.0.0.body.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n",
" 'model.input_proj_robot_state.weight',\n",
" 'model.input_proj_robot_state.bias',\n",
" 'model.cls_embed.weight',\n",
" 'model.encoder_action_proj.weight',\n",
" 'model.encoder_action_proj.bias',\n",
" 'model.encoder_joint_proj.weight',\n",
" 'model.encoder_joint_proj.bias',\n",
" 'model.latent_proj.weight',\n",
" 'model.latent_proj.bias',\n",
" 'model.latent_out_proj.weight',\n",
" 'model.latent_out_proj.bias',\n",
" 'model.additional_pos_embed.weight']\n"
]
}
],
"source": [
"orig = list(a.keys())\n",
"pprint(orig)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"a = torch.load(original_ckpt_path)\n",
"\n",
"to_remove_startswith = ['model.transformer.decoder.layers.1.',\n",
" 'model.transformer.decoder.layers.2.',\n",
" 'model.transformer.decoder.layers.3.',\n",
" 'model.transformer.decoder.layers.4.',\n",
" 'model.transformer.decoder.layers.5.',\n",
" 'model.transformer.decoder.layers.6.',\n",
" 'model.transformer.decoder.norm.',\n",
" 'model.is_pad_head']\n",
"\n",
"to_remove_in = ['num_batches_tracked',]\n",
"\n",
"conv = {}\n",
"\n",
"keys = list(a.keys())\n",
"for k in keys:\n",
" if any(k.startswith(tr) for tr in to_remove_startswith):\n",
" a.pop(k)\n",
" continue\n",
" if any(tr in k for tr in to_remove_in):\n",
" a.pop(k)\n",
" continue\n",
" if k.startswith('model.transformer.encoder.layers.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.transformer.decoder.layers.0.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.encoder.layers.'):\n",
" conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n",
" if k.startswith('model.action_head.'):\n",
" conv[k] = a.pop(k)\n",
" if k.startswith('model.pos_table'):\n",
" conv[k.replace('pos_table', 'vae_encoder_pos_enc')] = a.pop(k)\n",
" if k.startswith('model.query_embed.'):\n",
" conv[k.replace('query_embed', 'decoder_pos_embed')] = a.pop(k)\n",
" if k.startswith('model.input_proj.'):\n",
" conv[k.replace('input_proj.', 'encoder_img_feat_input_proj.')] = a.pop(k)\n",
" if k.startswith('model.input_proj_robot_state.'):\n",
" conv[k.replace('input_proj_robot_state.', 'encoder_robot_state_input_proj.')] = a.pop(k)\n",
" if k.startswith('model.backbones.0.0.body.'):\n",
" conv[k.replace('backbones.0.0.body', 'backbone')] = a.pop(k)\n",
" if k.startswith('model.cls_embed.'):\n",
" conv[k.replace('cls_embed', 'vae_encoder_cls_embed')] = a.pop(k)\n",
" if k.startswith('model.encoder_action_proj.'):\n",
" conv[k.replace('encoder_action_proj', 'vae_encoder_action_input_proj')] = a.pop(k)\n",
" if k.startswith('model.encoder_joint_proj.'):\n",
" conv[k.replace('encoder_joint_proj', 'vae_encoder_robot_state_input_proj')] = a.pop(k)\n",
" if k.startswith('model.latent_proj.'):\n",
" conv[k.replace('latent_proj', 'vae_encoder_latent_output_proj')] = a.pop(k)\n",
" if k.startswith('model.latent_out_proj.'):\n",
" conv[k.replace('latent_out_proj', 'encoder_latent_input_proj')] = a.pop(k)\n",
" if k.startswith('model.additional_pos_embed.'):\n",
" conv[k.replace('additional_pos_embed', 'encoder_robot_and_latent_pos_embed')] = a.pop(k)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict()"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"for k, v in conv.items():\n",
" assert b[k].shape == v.shape\n",
" b[k] = v"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"save_file(b, converted_ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/config.yaml'"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now also copy the config files\n",
"import shutil\n",
"shutil.copy(comparison_config_json_path, converted_ckpt_path.replace('model.safetensors', 'config.json'))\n",
"shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
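The notebook above maps the original ACT repository's checkpoint keys onto the names used by LeRobot's ACT policy, drops the unused decoder layers 1-6 and the pad head, and saves the result with safetensors. Below is a condensed, standalone sketch of that same conversion; the checkpoint paths and the reference state dict `b` (taken from a freshly instantiated LeRobot ACT policy) are illustrative assumptions rather than anything defined in this diff.

```python
# Condensed sketch of the state-dict conversion performed in the notebook above.
# `original_ckpt_path`, `converted_ckpt_path` and the reference dict `b` are assumptions.
import torch
from safetensors.torch import save_file

DROP_PREFIXES = tuple(f"model.transformer.decoder.layers.{i}." for i in range(1, 7)) + (
    "model.transformer.decoder.norm.",
    "model.is_pad_head",
)

# Original ACT key prefix -> LeRobot ACT key prefix (checked with str.startswith).
PREFIX_RENAMES = {
    "model.transformer.encoder.layers.": "model.encoder.layers.",
    "model.transformer.decoder.layers.0.": "model.decoder.layers.0.",
    "model.encoder.layers.": "model.vae_encoder.layers.",
    "model.action_head.": "model.action_head.",
    "model.pos_table": "model.vae_encoder_pos_enc",
    "model.query_embed": "model.decoder_pos_embed",
    "model.input_proj_robot_state.": "model.encoder_robot_state_input_proj.",
    "model.input_proj.": "model.encoder_img_feat_input_proj.",
    "model.backbones.0.0.body.": "model.backbone.",
    "model.cls_embed": "model.vae_encoder_cls_embed",
    "model.encoder_action_proj": "model.vae_encoder_action_input_proj",
    "model.encoder_joint_proj": "model.vae_encoder_robot_state_input_proj",
    "model.latent_out_proj": "model.encoder_latent_input_proj",
    "model.latent_proj": "model.vae_encoder_latent_output_proj",
    "model.additional_pos_embed": "model.encoder_robot_and_latent_pos_embed",
}


def convert_act_state_dict(original: dict, reference: dict) -> dict:
    """Rename `original` (ACT repo) keys into the layout of `reference` (LeRobot policy)."""
    converted = dict(reference)
    for key, value in original.items():
        if key.startswith(DROP_PREFIXES) or "num_batches_tracked" in key:
            continue  # unused decoder layers, pad head and BN counters are not carried over
        for old, new in PREFIX_RENAMES.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                assert converted[new_key].shape == value.shape
                converted[new_key] = value
                break
    return converted


# Usage sketch:
# a = torch.load(original_ckpt_path)   # original ACT checkpoint (an OrderedDict)
# b = policy.state_dict()              # freshly instantiated LeRobot ACT policy
# save_file(convert_act_state_dict(a, b), converted_ckpt_path)
```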

View File

@@ -1,8 +0,0 @@
from gymnasium.envs.registration import register
register(
id="gym_real_world/RealEnv-v0",
entry_point="gym_real_world.gym_environment:RealEnv",
max_episode_steps=300,
nondeterministic=True,
)

View File

@@ -1,363 +0,0 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Dynamixel class to control the dynamixel servos
"""
from __future__ import annotations
import enum
import math
import os
import time
from dataclasses import dataclass
import numpy as np
from dynamixel_sdk import * # Uses Dynamixel SDK library
def pos2pwm(pos: np.ndarray) -> np.ndarray:
"""
:param pos: numpy array of joint positions in range [-pi, pi]
:return: numpy array of pwm values in range [0, 4096]
"""
return ((pos / 3.14 + 1.0) * 2048).astype(np.int64)
def pwm2pos(pwm: np.ndarray) -> np.ndarray:
"""
:param pwm: numpy array of pwm values in range [0, 4096]
:return: numpy array of joint positions in range [-pi, pi]
"""
return (pwm / 2048 - 1) * 3.14
def pwm2vel(pwm: np.ndarray) -> np.ndarray:
"""
:param pwm: numpy array of pwm/s joint velocities
:return: numpy array of rad/s joint velocities
"""
return pwm * 3.14 / 2048
def vel2pwm(vel: np.ndarray) -> np.ndarray:
"""
:param vel: numpy array of rad/s joint velocities
:return: numpy array of pwm/s joint velocities
"""
return (vel * 2048 / 3.14).astype(np.int64)
class ReadAttribute(enum.Enum):
TEMPERATURE = 146
VOLTAGE = 145
VELOCITY = 128
POSITION = 132
CURRENT = 126
PWM = 124
HARDWARE_ERROR_STATUS = 70
HOMING_OFFSET = 20
BAUDRATE = 8
class OperatingMode(enum.Enum):
VELOCITY = 1
POSITION = 3
CURRENT_CONTROLLED_POSITION = 5
PWM = 16
UNKNOWN = -1
class Dynamixel:
ADDR_TORQUE_ENABLE = 64
ADDR_GOAL_POSITION = 116
ADDR_VELOCITY_LIMIT = 44
ADDR_GOAL_PWM = 100
OPERATING_MODE_ADDR = 11
POSITION_I = 82
POSITION_P = 84
ADDR_ID = 7
@dataclass
class Config:
def instantiate(self):
return Dynamixel(self)
baudrate: int = 57600
protocol_version: float = 2.0
device_name: str = "" # e.g. /dev/tty.usbserial-1120
dynamixel_id: int = 1
def __init__(self, config: Config):
self.config = config
self.connect()
def connect(self):
if self.config.device_name == "":
for port_name in os.listdir("/dev"):
if "ttyUSB" in port_name or "ttyACM" in port_name:
self.config.device_name = "/dev/" + port_name
print(f"using device {self.config.device_name}")
self.portHandler = PortHandler(self.config.device_name)
# self.portHandler.LA
self.packetHandler = PacketHandler(self.config.protocol_version)
if not self.portHandler.openPort():
raise Exception(f"Failed to open port {self.config.device_name}")
if not self.portHandler.setBaudRate(self.config.baudrate):
raise Exception(f"failed to set baudrate to {self.config.baudrate}")
# self.operating_mode = OperatingMode.UNKNOWN
# self.torque_enabled = False
# self._disable_torque()
self.operating_modes = [None for _ in range(32)]
self.torque_enabled = [None for _ in range(32)]
return True
def disconnect(self):
self.portHandler.closePort()
def set_goal_position(self, motor_id, goal_position):
# if self.operating_modes[motor_id] is not OperatingMode.POSITION:
# self._disable_torque(motor_id)
# self.set_operating_mode(motor_id, OperatingMode.POSITION)
# if not self.torque_enabled[motor_id]:
# self._enable_torque(motor_id)
# self._enable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, self.ADDR_GOAL_POSITION, goal_position
)
# self._process_response(dxl_comm_result, dxl_error)
# print(f'set position of motor {motor_id} to {goal_position}')
def set_pwm_value(self, motor_id: int, pwm_value, tries=3):
if self.operating_modes[motor_id] is not OperatingMode.PWM:
self._disable_torque(motor_id)
self.set_operating_mode(motor_id, OperatingMode.PWM)
if not self.torque_enabled[motor_id]:
self._enable_torque(motor_id)
# print(f'enabling torque')
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.ADDR_GOAL_PWM, pwm_value
)
# self._process_response(dxl_comm_result, dxl_error)
# print(f'set pwm of motor {motor_id} to {pwm_value}')
if dxl_comm_result != COMM_SUCCESS:
if tries <= 1:
raise ConnectionError(f"dxl_comm_result: {self.packetHandler.getTxRxResult(dxl_comm_result)}")
else:
print(f"dynamixel pwm setting failure trying again with {tries - 1} tries")
self.set_pwm_value(motor_id, pwm_value, tries=tries - 1)
elif dxl_error != 0:
print(f"dxl error {dxl_error}")
raise ConnectionError(f"dynamixel error: {self.packetHandler.getTxRxResult(dxl_error)}")
def read_temperature(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.TEMPERATURE, 1)
def read_velocity(self, motor_id: int):
pos = self._read_value(motor_id, ReadAttribute.VELOCITY, 4)
if pos > 2**31:
pos -= 2**32
# print(f'read position {pos} for motor {motor_id}')
return pos
def read_position(self, motor_id: int):
pos = self._read_value(motor_id, ReadAttribute.POSITION, 4)
if pos > 2**31:
pos -= 2**32
# print(f'read position {pos} for motor {motor_id}')
return pos
def read_position_degrees(self, motor_id: int) -> float:
return (self.read_position(motor_id) / 4096) * 360
def read_position_radians(self, motor_id: int) -> float:
return (self.read_position(motor_id) / 4096) * 2 * math.pi
def read_current(self, motor_id: int):
current = self._read_value(motor_id, ReadAttribute.CURRENT, 2)
if current > 2**15:
current -= 2**16
return current
def read_present_pwm(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.PWM, 2)
def read_hardware_error_status(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.HARDWARE_ERROR_STATUS, 1)
def disconnect(self):
self.portHandler.closePort()
def set_id(self, old_id, new_id, use_broadcast_id: bool = False):
"""
sets the id of the dynamixel servo
@param old_id: current id of the servo
@param new_id: new id
@param use_broadcast_id: set ids of all connected dynamixels if True.
If False, change only servo with self.config.id
@return:
"""
if use_broadcast_id:
current_id = 254
else:
current_id = old_id
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, current_id, self.ADDR_ID, new_id
)
self._process_response(dxl_comm_result, dxl_error, old_id)
self.config.dynamixel_id = new_id
def _enable_torque(self, motor_id):
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 1
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.torque_enabled[motor_id] = True
def _disable_torque(self, motor_id):
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 0
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.torque_enabled[motor_id] = False
def _process_response(self, dxl_comm_result: int, dxl_error: int, motor_id: int):
if dxl_comm_result != COMM_SUCCESS:
raise ConnectionError(
f"dxl_comm_result for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_comm_result)}"
)
elif dxl_error != 0:
print(f"dxl error {dxl_error}")
raise ConnectionError(
f"dynamixel error for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_error)}"
)
def set_operating_mode(self, motor_id: int, operating_mode: OperatingMode):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.OPERATING_MODE_ADDR, operating_mode.value
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.operating_modes[motor_id] = operating_mode
def set_pwm_limit(self, motor_id: int, limit: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(self.portHandler, motor_id, 36, limit)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_velocity_limit(self, motor_id: int, velocity_limit):
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, self.ADDR_VELOCITY_LIMIT, velocity_limit
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_P(self, motor_id: int, P: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.POSITION_P, P
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_I(self, motor_id: int, I: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.POSITION_I, I
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def read_home_offset(self, motor_id: int):
self._disable_torque(motor_id)
# dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(self.portHandler, motor_id,
# ReadAttribute.HOMING_OFFSET.value, home_position)
home_offset = self._read_value(motor_id, ReadAttribute.HOMING_OFFSET, 4)
# self._process_response(dxl_comm_result, dxl_error)
self._enable_torque(motor_id)
return home_offset
def set_home_offset(self, motor_id: int, home_position: int):
self._disable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, ReadAttribute.HOMING_OFFSET.value, home_position
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self._enable_torque(motor_id)
def set_baudrate(self, motor_id: int, baudrate):
# translate baudrate into dynamixel baudrate setting id
if baudrate == 57600:
baudrate_id = 1
elif baudrate == 1_000_000:
baudrate_id = 3
elif baudrate == 2_000_000:
baudrate_id = 4
elif baudrate == 3_000_000:
baudrate_id = 5
elif baudrate == 4_000_000:
baudrate_id = 6
else:
raise Exception("baudrate not implemented")
self._disable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, ReadAttribute.BAUDRATE.value, baudrate_id
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def _read_value(self, motor_id, attribute: ReadAttribute, num_bytes: int, tries=10):
try:
if num_bytes == 1:
value, dxl_comm_result, dxl_error = self.packetHandler.read1ByteTxRx(
self.portHandler, motor_id, attribute.value
)
elif num_bytes == 2:
value, dxl_comm_result, dxl_error = self.packetHandler.read2ByteTxRx(
self.portHandler, motor_id, attribute.value
)
elif num_bytes == 4:
value, dxl_comm_result, dxl_error = self.packetHandler.read4ByteTxRx(
self.portHandler, motor_id, attribute.value
)
except Exception:
if tries == 0:
raise Exception
else:
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
if dxl_comm_result != COMM_SUCCESS:
if tries <= 1:
# print("%s" % self.packetHandler.getTxRxResult(dxl_comm_result))
raise ConnectionError(f"dxl_comm_result {dxl_comm_result} for servo {motor_id} value {value}")
else:
print(f"dynamixel read failure for servo {motor_id} trying again with {tries - 1} tries")
time.sleep(0.02)
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
elif dxl_error != 0: # # print("%s" % self.packetHandler.getRxPacketError(dxl_error))
# raise ConnectionError(f'dxl_error {dxl_error} binary ' + "{0:b}".format(37))
if tries == 0 and dxl_error != 128:
raise Exception(f"Failed to read value from motor {motor_id} error is {dxl_error}")
else:
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
return value
def set_home_position(self, motor_id: int):
print(f"setting home position for motor {motor_id}")
self.set_home_offset(motor_id, 0)
current_position = self.read_position(motor_id)
print(f"position before {current_position}")
self.set_home_offset(motor_id, -current_position)
# dynamixel.set_home_offset(motor_id, -4096)
# dynamixel.set_home_offset(motor_id, -4294964109)
current_position = self.read_position(motor_id)
# print(f'signed position {current_position - 2** 32}')
print(f"position after {current_position}")
if __name__ == "__main__":
dynamixel = Dynamixel.Config(baudrate=1_000_000, device_name="/dev/tty.usbmodem57380045631").instantiate()
motor_id = 1
pos = dynamixel.read_position(motor_id)
for i in range(10):
s = time.monotonic()
pos = dynamixel.read_position(motor_id)
delta = time.monotonic() - s
print(f"read position took {delta}")
print(f"position {pos}")

View File

@@ -1,192 +0,0 @@
import time
from unittest.mock import MagicMock
import cv2
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from .dynamixel import pos2pwm, pwm2pos
from .robot import Robot
FPS = 30
CAMERAS_SHAPES = {
"images.high": (480, 640, 3),
"images.low": (480, 640, 3),
}
CAMERAS_PORTS = {
"images.high": "/dev/video6",
"images.low": "/dev/video0",
}
LEADER_PORT = "/dev/ttyACM1"
FOLLOWER_PORT = "/dev/ttyACM0"
MockRobot = MagicMock()
MockRobot.read_position = MagicMock()
MockRobot.read_position.return_value = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
MockCamera = MagicMock()
MockCamera.isOpened = MagicMock(return_value=True)
MockCamera.read = MagicMock(return_value=(True, np.zeros((480, 640, 3), dtype=np.uint8)))
def capture_image(cam, cam_width, cam_height):
# Capture a single frame
_, frame = cam.read()
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# # Define your crop coordinates (top left corner and bottom right corner)
# x1, y1 = 400, 0 # Example starting coordinates (top left of the crop rectangle)
# x2, y2 = 1600, 900 # Example ending coordinates (bottom right of the crop rectangle)
# # Crop the image
# image = image[y1:y2, x1:x2]
# Resize the image
image = cv2.resize(image, (cam_width, cam_height), interpolation=cv2.INTER_AREA)
return image
class RealEnv(gym.Env):
metadata = {}
def __init__(
self,
record: bool = False,
num_joints: int = 6,
cameras_shapes: dict = CAMERAS_SHAPES,
cameras_ports: dict = CAMERAS_PORTS,
follower_port: str = FOLLOWER_PORT,
leader_port: str = LEADER_PORT,
warmup_steps: int = 100,
trigger_torque=70,
fps: int = FPS,
fps_tolerance: float = 0.1,
mock: bool = False,
):
self.num_joints = num_joints
self.cameras_shapes = cameras_shapes
self.cameras_ports = cameras_ports
self.warmup_steps = warmup_steps
assert len(self.cameras_shapes) == len(self.cameras_ports), "Number of cameras and shapes must match."
self.follower_port = follower_port
self.leader_port = leader_port
self.record = record
self.fps = fps
self.fps_tolerance = fps_tolerance
# Initialize the robot
self.follower = Robot(device_name=self.follower_port) if not mock else MockRobot
if self.record:
self.leader = Robot(device_name=self.leader_port) if not mock else MockRobot
self.leader.set_trigger_torque(trigger_torque)
# Initialize the cameras - sorted by camera names
self.cameras = {}
for cn, p in sorted(self.cameras_ports.items()):
self.cameras[cn] = cv2.VideoCapture(p) if not mock else MockCamera
if not self.cameras[cn].isOpened():
raise OSError(
f"Cannot open camera port {p} for {cn}."
f" Make sure the camera is connected and the port is correct."
f"Also check you are not spinning several instances of the same environment (eval.batch_size)"
)
# Specify gym action and observation spaces
observation_space = {}
if self.num_joints > 0:
observation_space["agent_pos"] = spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
)
if self.record:
observation_space["leader_pos"] = spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
)
if self.cameras_shapes:
for cn, hwc_shape in self.cameras_shapes.items():
# Assumes images are unsigned int8 in [0,255]
observation_space[cn] = spaces.Box(
low=0,
high=255,
# height x width x channels (e.g. 480 x 640 x 3)
shape=hwc_shape,
dtype=np.uint8,
)
self.observation_space = spaces.Dict(observation_space)
self.action_space = spaces.Box(low=-1, high=1, shape=(num_joints,), dtype=np.float32)
self._observation = {}
self._terminated = False
self.timestamps = []
def _get_obs(self):
qpos = self.follower.read_position()
self._observation["agent_pos"] = pwm2pos(qpos)
for cn, c in self.cameras.items():
self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])
if self.record:
action = self.leader.read_position()
self._observation["leader_pos"] = pwm2pos(action)
def reset(self, seed: int | None = None):
# Reset the robot and sync the leader and follower if we are recording
for _ in range(self.warmup_steps):
self._get_obs()
if self.record:
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
self._terminated = False
info = {}
self.timestamps = []
return self._observation, info
def step(self, action: np.ndarray = None):
if self.timestamps:
# wait the right amount of time to stay at the desired fps
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
self.timestamps.append(time.time())
# Get the observation
self._get_obs()
if self.record:
# Teleoperate the leader
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
else:
# Apply the action to the follower
self.follower.set_goal_pos(pos2pwm(action))
reward = 0
terminated = truncated = self._terminated
info = {"timestamp": self.timestamps[-1] - self.timestamps[0], "fps_error": False}
# Check if we are able to keep up with the desired fps
if len(self.timestamps) > 1 and (self.timestamps[-1] - self.timestamps[-2]) > 1 / (
self.fps - self.fps_tolerance
):
print(
f"Error: recording fps {1 / (self.timestamps[-1] - self.timestamps[-2]):.5f} is lower"
f" than min admited fps {(self.fps - self.fps_tolerance):.5f}"
f" at frame {len(self.timestamps)}"
)
info["fps_error"] = True
return self._observation, reward, terminated, truncated, info
def render(self): ...
def close(self):
self.follower._disable_torque()
if self.record:
self.leader._disable_torque()

View File

@@ -1,173 +0,0 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Class to control the robot using dynamixel servos.
"""
from enum import Enum, auto
from typing import Union
import numpy as np
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD, GroupSyncRead, GroupSyncWrite
from .dynamixel import Dynamixel, OperatingMode, ReadAttribute
class MotorControlType(Enum):
PWM = auto()
POSITION_CONTROL = auto()
DISABLED = auto()
UNKNOWN = auto()
class Robot:
def __init__(self, device_name: str, baudrate=1_000_000, servo_ids=[1, 2, 3, 4, 5, 6]) -> None:
self.servo_ids = servo_ids
self.dynamixel = Dynamixel.Config(baudrate=baudrate, device_name=device_name).instantiate()
self._init_motors()
def _init_motors(self):
self.position_reader = GroupSyncRead(
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.POSITION.value, 4
)
for id in self.servo_ids:
self.position_reader.addParam(id)
self.velocity_reader = GroupSyncRead(
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.VELOCITY.value, 4
)
for id in self.servo_ids:
self.velocity_reader.addParam(id)
self.pos_writer = GroupSyncWrite(
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_POSITION, 4
)
for id in self.servo_ids:
self.pos_writer.addParam(id, [2048])
self.pwm_writer = GroupSyncWrite(
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_PWM, 2
)
for id in self.servo_ids:
self.pwm_writer.addParam(id, [2048])
# self._disable_torque()
self.motor_control_state = MotorControlType.DISABLED
def read_position(self, tries=2):
"""
Reads the joint positions of the robot. 2048 is the center position. 0 and 4096 are 180 degrees in each direction.
:param tries: maximum number of tries to read the position
:return: list of joint positions in range [0, 4096]
"""
result = self.position_reader.txRxPacket()
if result != 0:
if tries > 0:
return self.read_position(tries=tries - 1)
else:
print("failed to read position!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
positions = []
for id in self.servo_ids:
position = self.position_reader.getData(id, ReadAttribute.POSITION.value, 4)
if position > 2**31:
position -= 2**32
positions.append(position)
return np.array(positions)
def read_velocity(self):
"""
Reads the joint velocities of the robot.
:return: list of joint velocities,
"""
self.velocity_reader.txRxPacket()
velocities = []
for id in self.servo_ids:
velocity = self.velocity_reader.getData(id, ReadAttribute.VELOCITY.value, 4)
if velocity > 2**31:
velocity -= 2**32
velocities.append(velocity)
return np.array(velocities)
def set_goal_pos(self, action):
"""
:param action: list or numpy array of target joint positions in range [0, 4096[
"""
if self.motor_control_state is not MotorControlType.POSITION_CONTROL:
self._set_position_control()
for i, motor_id in enumerate(self.servo_ids):
data_write = [
DXL_LOBYTE(DXL_LOWORD(action[i])),
DXL_HIBYTE(DXL_LOWORD(action[i])),
DXL_LOBYTE(DXL_HIWORD(action[i])),
DXL_HIBYTE(DXL_HIWORD(action[i])),
]
self.pos_writer.changeParam(motor_id, data_write)
self.pos_writer.txPacket()
def set_pwm(self, action):
"""
Sets the pwm values for the servos.
:param action: list or numpy array of pwm values in range [0, 885]
"""
if self.motor_control_state is not MotorControlType.PWM:
self._set_pwm_control()
for i, motor_id in enumerate(self.servo_ids):
data_write = [
DXL_LOBYTE(DXL_LOWORD(action[i])),
DXL_HIBYTE(DXL_LOWORD(action[i])),
]
self.pwm_writer.changeParam(motor_id, data_write)
self.pwm_writer.txPacket()
def set_trigger_torque(self, torque: int):
"""
Sets a constant torque for the last servo in the chain. This is useful for the trigger of the leader arm
"""
self.dynamixel._enable_torque(self.servo_ids[-1])
self.dynamixel.set_pwm_value(self.servo_ids[-1], torque)
def limit_pwm(self, limit: Union[int, list, np.ndarray]):
"""
Limits the pwm values for the servos in position control mode
@param limit: 0 ~ 885
@return:
"""
if isinstance(limit, int):
limits = [
limit,
] * 5
else:
limits = limit
self._disable_torque()
for motor_id, limit in zip(self.servo_ids, limits, strict=False):
self.dynamixel.set_pwm_limit(motor_id, limit)
self._enable_torque()
def _disable_torque(self):
print(f"disabling torque for servos {self.servo_ids}")
for motor_id in self.servo_ids:
self.dynamixel._disable_torque(motor_id)
def _enable_torque(self):
print(f"enabling torque for servos {self.servo_ids}")
for motor_id in self.servo_ids:
self.dynamixel._enable_torque(motor_id)
def _set_pwm_control(self):
self._disable_torque()
for motor_id in self.servo_ids:
self.dynamixel.set_operating_mode(motor_id, OperatingMode.PWM)
self._enable_torque()
self.motor_control_state = MotorControlType.PWM
def _set_position_control(self):
self._disable_torque()
for motor_id in self.servo_ids:
# TODO(rcadene): redesign
if motor_id == 9:
self.dynamixel.set_operating_mode(9, OperatingMode.CURRENT_CONTROLLED_POSITION)
else:
self.dynamixel.set_operating_mode(motor_id, OperatingMode.POSITION)
self._enable_torque()
self.motor_control_state = MotorControlType.POSITION_CONTROL
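Putting the pieces together, `RealEnv` above teleoperates by reading the leader arm's positions on every step and sending them to the follower. A stripped-down version of that loop, assuming the `gym_real_world` package layout and the default serial ports from `gym_environment.py` (`/dev/ttyACM1` leader, `/dev/ttyACM0` follower), might look like the following sketch.

```python
# Hedged teleoperation sketch built on the Robot class above.
# The import path and serial ports are assumptions taken from the defaults in this example.
import time

from gym_real_world.robot import Robot

leader = Robot(device_name="/dev/ttyACM1")
follower = Robot(device_name="/dev/ttyACM0")
leader.set_trigger_torque(70)  # keep a light constant torque on the leader's trigger servo

try:
    while True:
        # Positions are raw counts in [0, 4096]; RealEnv converts them with pwm2pos/pos2pwm.
        follower.set_goal_pos(leader.read_position())
        time.sleep(1 / 30)  # roughly the 30 fps used by RealEnv
except KeyboardInterrupt:
    follower._disable_torque()
    leader._disable_torque()
```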

View File

@@ -1,237 +0,0 @@
"""This script demonstrates how to record a LeRobot dataset of training data
using a very simple gym environment (see in examples/real_robot_example/gym_real_world/gym_environment.py).
"""
import argparse
import copy
import os
from pathlib import Path
import gym_real_world # noqa: F401
import gymnasium as gym
import numpy as np
import torch
from datasets import Dataset, Features, Sequence, Value
from omegaconf import OmegaConf
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
# parse the repo_id name via command line
parser = argparse.ArgumentParser()
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
parser.add_argument("--num-episodes", type=int, default=2)
parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument(
"--fps_tolerance",
type=float,
default=0.5,
help="Tolerance in fps for the recording before dropping episodes.",
)
parser.add_argument(
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
)
parser.add_argument("--gym-config", type=str, default=None, help="Path to the gym config file.")
parser.add_argument("--mock_robot", action="store_true")
args = parser.parse_args()
repo_id = args.repo_id
num_episodes = args.num_episodes
num_frames = args.num_frames
revision = args.revision
fps = args.fps
fps_tolerance = args.fps_tolerance
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
# During data collection, frames are stored as png images in `images_dir`
images_dir = out_data / "images"
# After data collection, png images of each episode are encoded into a mp4 file stored in `videos_dir`
videos_dir = out_data / "videos"
meta_data_dir = out_data / "meta_data"
gym_config = None
if args.gym_config is not None:
gym_config = OmegaConf.load(args.gym_config)
# Create image and video directories
if not os.path.exists(images_dir):
os.makedirs(images_dir, exist_ok=True)
if not os.path.exists(videos_dir):
os.makedirs(videos_dir, exist_ok=True)
if __name__ == "__main__":
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
gym_handle = "gym_real_world/RealEnv-v0"
gym_kwargs = {}
if gym_config is not None:
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
env = gym.make(
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=args.mock_robot, **gym_kwargs
)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
ep_fps = []
id_from = 0
id_to = 0
os.system('spd-say "gym environment created"')
ep_idx = 0
while ep_idx < num_episodes:
# bring the follower to the leader and start camera
env.reset()
os.system(f'spd-say "go {ep_idx}"')
# init buffers
obs_replay = {k: [] for k in env.observation_space}
drop_episode = False
timestamps = []
for _ in tqdm(range(num_frames)):
# Apply the next action
observation, _, _, _, info = env.step(action=None)
# images_stacked = np.hstack(list(observation['pixels'].values()))
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
# cv2.imshow('frame', images_stacked)
if info["fps_error"]:
os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
drop_episode = True
break
# store data
for key in observation:
obs_replay[key].append(copy.deepcopy(observation[key]))
timestamps.append(info["timestamp"])
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
os.system('spd-say "stop"')
if not drop_episode:
os.system(f'spd-say "saving episode {ep_idx}"')
ep_dict = {}
# store images in png and create the video
for img_key in env.cameras:
save_images_concurrently(
obs_replay[img_key],
images_dir / f"{img_key}_episode_{ep_idx:06d}",
args.num_workers,
)
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
# store the reference to the video frame
ep_dict[f"observation.{img_key}"] = [
{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
]
state = torch.tensor(np.array(obs_replay["agent_pos"]))
action = torch.tensor(np.array(obs_replay["leader_pos"]))
next_done = torch.zeros(num_frames, dtype=torch.bool)
next_done[-1] = True
ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.tensor(timestamps)
ep_dict["next.done"] = next_done
ep_fps.append(num_frames / timestamps[-1])
ep_dicts.append(ep_dict)
print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}")
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(
id_from + num_frames if args.keep_last else id_from + num_frames - 1
)
id_to = id_from + num_frames if args.keep_last else id_from + num_frames - 1
id_from = id_to
ep_idx += 1
env.close()
os.system('spd-say "encode video frames"')
for ep_idx in range(num_episodes):
for img_key in env.cameras:
# If necessary, we may want to encode the video
# with variable frame rate: https://superuser.com/questions/1661901/encoding-video-from-vfr-still-images
encode_video_frames(
images_dir / f"{img_key}_episode_{ep_idx:06d}",
videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4",
ep_fps[ep_idx],
)
os.system('spd-say "concatenate episodes"')
data_dict = concatenate_episodes(
ep_dicts, drop_episodes_last_frame=not args.keep_last
) # Since our fps varies we are sometimes off tolerance for the last frame
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
features[key] = VideoFrame()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": sum(ep_fps) / len(ep_fps), # to have a good tolerance in data processing for the slowest video
"video": 1,
}
os.system('spd-say "from preloaded"')
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
version=revision,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
os.system('spd-say "compute stats"')
stats = compute_stats(lerobot_dataset)
os.system('spd-say "save to disk"')
hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
hf_dataset.save_to_disk(str(out_data / "train"))
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if args.push_to_hub:
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
push_videos_to_hub(repo_id, videos_dir, revision="main")
push_videos_to_hub(repo_id, videos_dir, revision=revision)

View File

@@ -1,60 +0,0 @@
import argparse
import logging
from pathlib import Path
import gym_real_world # noqa: F401
import gymnasium as gym # noqa: F401
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.eval import eval
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
try:
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)


@@ -1,103 +0,0 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 2 cameras (i.e. high, low) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1. This parameter sets the frequency (in training steps) at which
# checkpoints are evaluated; setting it to -1 deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: ???
override_dataset_stats:
observation.images.high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 1000
online_steps: 0
eval_freq: -1
save_freq: 1000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.high: [3, 480, 640]
observation.images.low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.high: mean_std
observation.images.low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0


@@ -1,103 +0,0 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 2 cameras (i.e. top, front) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1. This parameter sets the frequency (in training steps) at which
# checkpoints are evaluated; setting it to -1 deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: ???
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.front:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 1000
online_steps: 0
eval_freq: -1
save_freq: 1000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.images.front: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.top: mean_std
observation.images.front: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

gym_dora/README.md Normal file

@@ -0,0 +1 @@
# gym_dora

gym_dora/example.py Normal file

@@ -0,0 +1,17 @@
import gymnasium as gym
import gym_dora  # noqa: F401

env = gym.make("gym_dora/DoraAloha-v0", disable_env_checker=True)
obs, info = env.reset()

policy = ...  # make_policy

done = False
while not done:
    actions = policy.select_action(obs)
    obs, reward, terminated, truncated, info = env.step(actions)
    done = terminated | truncated | done

env.close()


@@ -0,0 +1,17 @@
from gymnasium.envs.registration import register
register(
id="gym_dora/DoraAloha-v0",
entry_point="gym_dora.env:DoraEnv",
max_episode_steps=300,
nondeterministic=True,
kwargs={"model": "aloha"},
)
register(
id="gym_dora/DoraKoch-v0",
entry_point="gym_dora.env:DoraEnv",
max_episode_steps=300,
nondeterministic=True,
kwargs={"model": "koch"},
)
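The two entries above cover the stock Aloha and Koch setups. For a different arm, the `DoraEnv` constructor in `gym_dora/gym_dora/env.py` (shown below) also accepts a `custom` model with explicit camera names and joint/action counts. The following is only a hedged sketch of such a registration; the id, camera name and dimensions are placeholders, not part of this diff:

```python
import gymnasium as gym
import gym_dora  # noqa: F401
from gymnasium.envs.registration import register

# Hypothetical registration for a single 6-DoF arm with one camera.
# The kwargs are forwarded to DoraEnv(model="custom", ...).
register(
    id="gym_dora/DoraCustomArm-v0",
    entry_point="gym_dora.env:DoraEnv",
    max_episode_steps=300,
    nondeterministic=True,
    kwargs={
        "model": "custom",
        "cameras_names": ["cam_front"],  # names must start with "cam"
        "num_joints": 6,
        "num_actions": 6,
    },
)

env = gym.make("gym_dora/DoraCustomArm-v0", disable_env_checker=True)
```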

gym_dora/gym_dora/env.py Normal file

@@ -0,0 +1,199 @@
import os
import gymnasium as gym
import numpy as np
import pyarrow as pa
from dora import Node
from gymnasium import spaces
FPS = int(os.getenv("FPS", "30"))
IMAGE_WIDTH = int(os.getenv("IMAGE_WIDTH", "640"))
IMAGE_HEIGHT = int(os.getenv("IMAGE_HEIGHT", "480"))
ALOHA_JOINTS = [
# absolute joint position
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"left_arm_gripper",
# absolute joint position
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"right_arm_gripper",
]
ALOHA_ACTIONS = [
# absolute joint position
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"left_arm_gripper",
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"right_arm_gripper",
]
class DoraEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": FPS}
def __init__(
self,
model="aloha",
observation_width=IMAGE_WIDTH,
observation_height=IMAGE_HEIGHT,
cameras_names=None,
num_joints=None,
num_actions=None,
):
"""Initializes the Dora environment.
Args:
model (str): The model to use. Either 'aloha' or 'custom'.
observation_width (int): The width of the observation image.
observation_height (int): The height of the observation image.
cameras_names (list): A list of camera names to use with the 'custom' model (each name must start with 'cam'). For 'aloha', the four default cameras ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] are used.
num_joints (int): The number of joints in the model. 14 for 'aloha'; must be provided for 'custom'.
num_actions (int): The number of actions in the model. 14 for 'aloha'; must be provided for 'custom'.
"""
super().__init__()
# Initialize a new node
self.node = Node() if os.environ.get("DORA_NODE_CONFIG", None) is not None else None
self.observation = {"pixels": {}, "agent_pos": None}
self.terminated = False
self.observation_height = observation_height
self.observation_width = observation_width
# Observation space
if model == "aloha":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(
{
"cam_high": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_low": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_left_wrist": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
"cam_right_wrist": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
),
}
),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(len(ALOHA_JOINTS),),
dtype=np.float64,
),
}
)
elif model == "custom":
pixel_dict = {}
for camera in cameras_names:
assert camera.startswith("cam"), "Camera names must start with 'cam'"
pixel_dict[camera] = spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(pixel_dict),
"agent_pos": spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
),
}
)
else:
raise ValueError("Model must be either 'aloha' or 'custom'.")
# Action space
if model == "aloha":
self.action_space = spaces.Box(low=-1, high=1, shape=(len(ALOHA_ACTIONS),), dtype=np.float32)
elif model == "custom":
self.action_space = spaces.Box(low=-1, high=1, shape=(num_actions,), dtype=np.float32)
def _get_obs(self):
while True:
event = self.node.next(timeout=0.001)
## If event is None, the node event stream is closed and we should terminate the env
if event is None:
self.terminated = True
break
if event["type"] == "INPUT":
# Map Image input into pixels key within Aloha environment
if "cam" in event["id"]:
self.observation["pixels"][event["id"]] = (
event["value"].to_numpy().reshape(self.observation_height, self.observation_width, 3)
)
else:
# Map other inputs into the observation dictionary using the event id as key
self.observation[event["id"]] = event["value"].to_numpy()
# If the event is a timeout error break the update loop.
elif event["type"] == "ERROR":
break
def reset(self, seed: int | None = None):
self.node.send_output("reset")
self._get_obs()
self.terminated = False
info = {}
return self.observation, info
def step(self, action: np.ndarray):
# Send the action to the dataflow as action key.
self.node.send_output("action", pa.array(action))
self._get_obs()
reward = 0
terminated = truncated = self.terminated
info = {}
return self.observation, reward, terminated, truncated, info
def render(self): ...
def close(self):
# Drop the node
del self.node
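For reference, a hedged sketch of what the spaces defined above resolve to for the `aloha` model, assuming the default `FPS`/`IMAGE_WIDTH`/`IMAGE_HEIGHT` environment variables. Constructing the env does not require a running dora dataflow (the `Node` is only created when `DORA_NODE_CONFIG` is set), but stepping it does:

```python
import gymnasium as gym
import gym_dora  # noqa: F401

env = gym.make("gym_dora/DoraAloha-v0", disable_env_checker=True)

print(env.observation_space["agent_pos"].shape)           # (14,) float64 joint positions
print(env.observation_space["pixels"]["cam_high"].shape)  # (480, 640, 3) uint8 frames
print(env.action_space.shape)                             # (14,) float32 joint targets
env.close()
```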

gym_dora/poetry.lock generated Normal file

@@ -0,0 +1,182 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "cloudpickle"
version = "3.0.0"
description = "Pickler class to extend the standard pickle.Pickler functionality"
optional = false
python-versions = ">=3.8"
files = [
{file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"},
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
]
[[package]]
name = "dora-rs"
version = "0.3.4"
description = "`dora` goal is to be a low latency, composable, and distributed data flow."
optional = false
python-versions = "*"
files = [
{file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
{file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
{file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
{file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
{file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
]
[package.dependencies]
pyarrow = "*"
[[package]]
name = "farama-notifications"
version = "0.0.4"
description = "Notifications for all Farama Foundation maintained libraries."
optional = false
python-versions = "*"
files = [
{file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"},
{file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"},
]
[[package]]
name = "gymnasium"
version = "0.29.1"
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
optional = false
python-versions = ">=3.8"
files = [
{file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"},
{file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"},
]
[package.dependencies]
cloudpickle = ">=1.2.0"
farama-notifications = ">=0.0.1"
numpy = ">=1.21.0"
typing-extensions = ">=4.3.0"
[package.extras]
accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"]
all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"]
atari = ["shimmy[atari] (>=0.1.0,<1.0)"]
box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"]
classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"]
mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"]
mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"]
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"]
testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"]
toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
[[package]]
name = "numpy"
version = "1.26.4"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
[[package]]
name = "pyarrow"
version = "16.1.0"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"},
{file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"},
{file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"},
{file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"},
{file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"},
{file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"},
{file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"},
{file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"},
{file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"},
{file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"},
{file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"},
{file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"},
{file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"},
{file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"},
{file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"},
{file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"},
{file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"},
]
[package.dependencies]
numpy = ">=1.16.6"
[[package]]
name = "typing-extensions"
version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
{file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "7e437b5c547ebe11095f1ce4ff1851d636f8e707ad7de8a6224b0f9ad978240f"

gym_dora/pyproject.toml Normal file

@@ -0,0 +1,17 @@
[tool.poetry]
name = "gym-dora"
version = "0.1.0"
description = ""
authors = ["Simon Alibert <alibert.sim@gmail.com>"]
readme = "README.md"
packages = [{ include = "gym_dora" }]
[tool.poetry.dependencies]
python = "^3.10"
gymnasium = ">=0.29.1"
dora-rs = ">=0.3.4"
pyarrow = ">=12.0.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"


@@ -45,9 +45,6 @@ import itertools
from lerobot.__version__ import __version__ # noqa: F401
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
# a yaml file AND a environment name. The difference should be more obvious.
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
@@ -55,7 +52,6 @@ available_tasks_per_env = {
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
}
available_envs = list(available_tasks_per_env.keys())
@@ -81,23 +77,6 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
"dora_aloha_real": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
"lerobot/aloha_static_coffee_new",
"lerobot/aloha_static_cups_open",
"lerobot/aloha_static_fork_pick_up",
"lerobot/aloha_static_pingpong_test",
"lerobot/aloha_static_pro_pencil",
"lerobot/aloha_static_screw_driver",
"lerobot/aloha_static_tape",
"lerobot/aloha_static_thread_velcro",
"lerobot/aloha_static_towel",
"lerobot/aloha_static_vinh_cup",
"lerobot/aloha_static_vinh_cup_left",
"lerobot/aloha_static_ziploc_slide",
],
}
available_real_world_datasets = [
@@ -129,19 +108,16 @@ available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)
# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
available_policies = [
"act",
"diffusion",
"tdmpc",
]
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
"dora_aloha_real": ["act_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]


@@ -16,73 +16,34 @@
import logging
import torch
from omegaconf import ListConfig, OmegaConf
from omegaconf import OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def resolve_delta_timestamps(cfg):
"""Resolves delta_timestamps config key (in-place) by using `eval`.
def make_dataset(
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
Doesn't do anything if delta_timestamps is not specified or has already been resolved (as evidenced by
the data type of its values).
"""
delta_timestamps = cfg.training.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
# TODO(rcadene, alexander-soare): remove `eval` to avoid exploit
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
split: Select the data subset used to create an instance of LeRobotDataset.
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
slicer in the hugging face datasets:
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
Returns:
The LeRobotDataset.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError(
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
if not cfg.env.real_world:
if isinstance(cfg.dataset_repo_id, str):
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
else:
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
for dataset_repo_id in dataset_repo_ids:
if cfg.env.name not in dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
resolve_delta_timestamps(cfg)
delta_timestamps[key] = eval(delta_timestamps[key])
# TODO(rcadene): add data augmentations
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
)
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=delta_timestamps,
)
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():
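For reference, a hedged sketch of what the `eval`-based resolution of `delta_timestamps` amounts to. In the configs, the string contains `${fps}` and `${policy.chunk_size}` interpolations that OmegaConf resolves before `eval` is called; the values below (fps=50, chunk_size=100) are purely illustrative:

```python
# After interpolation of ${fps}=50 and ${policy.chunk_size}=100, the config string
# looks like this; the dataset factory then turns it into a list of offsets with eval().
delta_timestamps = {"action": "[i / 50 for i in range(1, 100 + 1)]"}
resolved = eval(delta_timestamps["action"])  # [0.02, 0.04, ..., 2.0]

assert len(resolved) == 100
assert abs(resolved[0] - 0.02) < 1e-9
```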


@@ -13,16 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable
import datasets
import torch
import torch.utils
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
load_episode_data_index,
@@ -46,7 +42,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -175,7 +171,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod
def from_preloaded(
cls,
repo_id: str = "from_preloaded",
repo_id: str,
version: str | None = CODEBASE_VERSION,
root: Path | None = None,
split: str = "train",
@@ -187,15 +183,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
stats=None,
info=None,
videos_dir=None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
on the filesystem or uploading to the hub.
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
meaningless depending on the downstream usage of the return dataset.
"""
):
# create an empty object of type LeRobotDataset
obj = cls.__new__(cls)
obj.repo_id = repo_id
@@ -207,193 +195,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
obj.stats = stats
obj.info = info if info is not None else {}
obj.info = info
obj.videos_dir = videos_dir
return obj
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.repo_ids = repo_ids
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
version=version,
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
)
for repo_id in repo_ids
]
# Check that some properties are consistent across datasets. Note: We may relax some of these
# consistency requirements in future iterations of this class.
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
if dataset.info != self._datasets[0].info:
raise ValueError(
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
"not yet supported."
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_data_keys = set()
intersection_data_keys = set(self._datasets[0].hf_dataset.features)
for dataset in self._datasets:
intersection_data_keys.intersection_update(dataset.hf_dataset.features)
if len(intersection_data_keys) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. The "
"multi-dataset functionality currently only keeps common keys."
)
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_samples(self) -> int:
"""Number of samples/frames."""
return sum(d.num_samples for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough to the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_samples
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_samples:
start_idx += dataset.num_samples
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f")"
)
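A hedged usage sketch for `MultiLeRobotDataset` as defined above; the repo IDs are taken from the dataset lists earlier in this diff and the `delta_timestamps` values are illustrative:

```python
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

# Concatenate two Aloha repos; data keys that are not common to both are disabled.
dataset = MultiLeRobotDataset(
    repo_ids=["lerobot/aloha_static_tape", "lerobot/aloha_static_towel"],
    delta_timestamps={"action": [i / 50 for i in range(100)]},
)

item = dataset[0]
print(item["dataset_index"])  # tensor identifying which underlying repo the frame comes from
print(dataset.num_episodes, dataset.num_samples)
```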


@@ -18,7 +18,6 @@ Contains utilities to process raw data format from dora-record
"""
import logging
import re
from pathlib import Path
import pandas as pd
@@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool:
return True
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
def load_from_raw(raw_dir: Path, out_dir: Path):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
@@ -57,45 +56,24 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
key = path.stem # action or observation.state or ...
if key == reference_key:
continue
if "failed_episode_index" in key:
# TODO(rcadene): add support for removing episodes that are tagged as "failed"
continue
modality_df = pd.read_parquet(path)
modality_df = modality_df[["timestamp_utc", key]]
df = pd.merge_asof(
df,
modality_df,
on="timestamp_utc",
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
# are too far appart.
direction="nearest",
tolerance=pd.Timedelta(f"{1/fps} seconds"),
direction="backward",
)
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
# Remove rows with a NaN in any column. It can happen during the first frames of an episode,
# because some cameras didn't start recording yet.
df = df.dropna(axis=0)
# Remove rows with episode_index -1 which indicates a failed episode
df = df[df["episode_index"] != -1]
image_keys = [key for key in df if "observation.images." in key]
def get_episode_index(row):
episode_index_per_cam = {}
for key in image_keys:
path = row[key][0]["path"]
match = re.search(r"_(\d{6}).mp4", path)
if not match:
raise ValueError(path)
episode_index = int(match.group(1))
episode_index_per_cam[key] = episode_index
if len(set(episode_index_per_cam.values())) != 1:
raise ValueError(
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
)
return episode_index
df["episode_index"] = df.apply(get_episode_index, axis=1)
# dora only use arrays, so single values are encapsulated into a list
df["episode_index"] = df["episode_index"].map(lambda x: x[0])
df["frame_index"] = df.groupby("episode_index").cumcount()
df = df.reset_index()
df["index"] = df.index
@@ -110,16 +88,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
del df["timestamp_utc"]
# sanity check
has_nan = df.isna().any().any()
if has_nan:
raise ValueError("Dataset contains Nan values.")
# sanity check episode indices go from 0 to n-1
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids:
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
assert ep_ids == expected_ep_ids, f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
# Create symlink to raw videos directory (that needs to be absolute not relative)
out_dir.mkdir(parents=True, exist_ok=True)
@@ -132,8 +104,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
continue
for ep_idx in ep_ids:
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
if not video_path.exists():
raise ValueError(f"Video file not found in {video_path}")
assert video_path.exists(), f"Video file not found in {video_path}"
data_dict = {}
for key in df:
@@ -145,8 +116,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
# sanity check the video path is well formated
video_path = videos_dir.parent / data_dict[key][0]["path"]
if not video_path.exists():
raise ValueError(f"Video file not found in {video_path}")
assert video_path.exists(), f"Video file not found in {video_path}"
# is number
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
data_dict[key] = torch.from_numpy(df[key].values)
@@ -220,7 +190,7 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
if not video:
raise NotImplementedError()
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
data_df, episode_data_index = load_from_raw(raw_dir, out_dir)
hf_dataset = to_hf_dataset(data_df, video)
info = {
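The synchronization above hinges on `pandas.merge_asof`. Below is a small, self-contained sketch (not from the repo) of how `direction="nearest"` with a `1/fps` tolerance aligns a modality stream onto the reference camera timestamps, leaving NaNs (later dropped) when no sample falls within the tolerance:

```python
import pandas as pd

fps = 30
ref = pd.DataFrame(
    {
        "timestamp_utc": pd.to_datetime([0, 33, 200], unit="ms"),
        "observation.images.cam_high": ["frame0", "frame1", "frame2"],
    }
)
state = pd.DataFrame(
    {
        "timestamp_utc": pd.to_datetime([2, 36, 150], unit="ms"),
        "observation.state": [0.0, 0.1, 0.2],
    }
)

# Each reference row gets the closest state sample within 1/fps seconds; the last
# camera frame has no state sample within tolerance and ends up with a NaN.
df = pd.merge_asof(
    ref, state, on="timestamp_utc", direction="nearest", tolerance=pd.Timedelta(f"{1/fps} seconds")
)
print(df)
```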


@@ -43,6 +43,9 @@ def get_cameras(hdf5_data):
def check_format(raw_dir) -> bool:
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
assert len(hdf5_paths) != 0
for hdf5_path in hdf5_paths:
@@ -59,15 +62,17 @@ def check_format(raw_dir) -> bool:
for camera in get_cameras(data):
assert num_frames == data[f"/observations/images/{camera}"].shape[0]
# ndim 2 when image are compressed and 4 when uncompressed
assert data[f"/observations/images/{camera}"].ndim in [2, 4]
if data[f"/observations/images/{camera}"].ndim == 4:
if compressed_images:
assert data[f"/observations/images/{camera}"].ndim == 2
else:
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(raw_dir, out_dir, fps, video, debug):
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
hdf5_files = list(raw_dir.glob("*.hdf5"))
ep_dicts = []
@@ -94,7 +99,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
for camera in get_cameras(ep):
img_key = f"observation.images.{camera}"
if ep[f"/observations/images/{camera}"].ndim == 2:
if compressed_images:
import cv2
# load one compressed image after the other in RAM and uncompress


@@ -16,15 +16,17 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset, num_workers=0):
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are in channel first format
@@ -64,8 +66,9 @@ def get_stats_einops_patterns(dataset, num_workers=0):
return stats_patterns
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
def compute_stats(
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
):
if max_num_samples is None:
max_num_samples = len(dataset)
@@ -156,54 +159,3 @@ def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
"min": min[key],
}
return stats
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
The final stats will have the union of all data keys from each of the datasets.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
)
return stats
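As a quick, hedged sanity check (not part of the diff) that the combined-std identity quoted in the docstring above holds when population statistics are used:

```python
import torch

d0 = torch.randn(100) * 2.0 + 1.0
d1 = torch.randn(300) * 0.5 - 3.0
n0, n1 = d0.numel(), d1.numel()

mean0, mean1 = d0.mean(), d1.mean()
std0, std1 = d0.std(unbiased=False), d1.std(unbiased=False)

# Weighted mean, then σ_combined = sqrt[(n1*(σ1² + d1²) + n2*(σ2² + d2²)) / (n1 + n2)]
mean_c = mean0 * (n0 / (n0 + n1)) + mean1 * (n1 / (n0 + n1))
std_c = torch.sqrt(
    (n0 * (std0**2 + (mean0 - mean_c) ** 2) + n1 * (std1**2 + (mean1 - mean_c) ** 2)) / (n0 + n1)
)

both = torch.cat([d0, d1])
assert torch.allclose(mean_c, both.mean(), atol=1e-4)
assert torch.allclose(std_c, both.std(unbiased=False), atol=1e-4)
```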


@@ -21,24 +21,19 @@ import PIL
import torch
def concatenate_episodes(ep_dicts, drop_episodes_last_frame=False):
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
if drop_episodes_last_frame:
data_dict[key] = torch.cat([ep_dict[key][:-1] for ep_dict in ep_dicts])
else:
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
if drop_episodes_last_frame:
data_dict[key].pop()
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)


@@ -1,61 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union
import torch
class EpisodeAwareSampler:
def __init__(
self,
episode_data_index: dict,
episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
):
"""Sampler that optionally incorporates episode boundary information.
Args:
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
self.indices = indices
self.shuffle = shuffle
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
yield i
def __len__(self) -> int:
return len(self.indices)
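A hedged usage sketch for `EpisodeAwareSampler` as defined above; the episode boundaries and the `dataset` it would be paired with are placeholders:

```python
import torch

# Two episodes: frames [0, 10) and [10, 25); drop the last 2 frames of each episode,
# e.g. to avoid boundary artifacts when training with action chunking.
episode_data_index = {"from": torch.tensor([0, 10]), "to": torch.tensor([10, 25])}
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=2, shuffle=True)

print(len(sampler))  # 21 frames remain (8 from the first episode, 13 from the second)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)
```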


@@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
return outdict
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
def hf_transform_to_torch(items_dict):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
@@ -73,8 +73,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif first_item is None:
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
@@ -320,7 +318,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""Reset the `episode_index` of the provided HuggingFace Dataset.
"""
Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
@@ -339,7 +338,6 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
return hf_dataset
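A hedged illustration of the intent described in the docstring above; the snippet below is a hypothetical stand-in, not the diff's actual implementation:

```python
from datasets import Dataset

# Episodes recorded as 3,3,7,7,7 are remapped to 0,0,1,1,1 so that downstream code
# (e.g. episode_data_index computation) can assume contiguous ids starting at 0.
ds = Dataset.from_dict({"episode_index": [3, 3, 7, 7, 7]})
mapping = {old: new for new, old in enumerate(sorted(set(ds["episode_index"])))}
ds = ds.map(lambda ex: {"episode_index": mapping[ex["episode_index"]]})
print(ds["episode_index"])  # [0, 0, 1, 1, 1]
```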


@@ -27,6 +27,14 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
if n_envs is not None and n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
kwargs = {
# "obs_type": "pixels_agent_pos",
# "render_mode": "rgb_array",
"max_episode_steps": cfg.env.episode_length,
# "visualization_width": 384,
# "visualization_height": 384,
}
package_name = f"gym_{cfg.env.name}"
try:
@@ -38,16 +46,12 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
raise e
gym_handle = f"{package_name}/{cfg.env.task}"
gym_kwgs = dict(cfg.env.get("gym", {}))
if cfg.env.get("episode_length"):
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
]
)


@@ -29,12 +29,10 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations and isinstance(observations["pixels"], dict):
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
imgs = {"observation.image": observations["pixels"]}
else:
imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}
imgs = {"observation.image": observations["pixels"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img)


@@ -13,33 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
# TODO(rcadene, alexander-soare): clean this file
"""
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
import logging
import os
import re
from glob import glob
from pathlib import Path
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from omegaconf import OmegaConf
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
def cfg_to_group(cfg, return_list=False):
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.name}",
@@ -50,54 +42,22 @@ def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str:
return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
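The run ID is recovered purely from the filename that wandb leaves under `wandb/latest-run/`; a quick illustration of the same regex on a made-up filename:

```python
import re

# Hypothetical filename found by the glob above.
filename = "run-1a2b3c4d.wandb"
match = re.search(r"run-([^\.]+).wandb", filename)
assert match is not None
print(match.groups(0)[0])  # 1a2b3c4d
```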
class Logger:
"""Primary logger object. Logs either locally or using wandb.
"""Primary logger object. Logs either locally or using wandb."""
The logger creates the following directory structure:
provided_log_dir
├── .hydra # hydra's configuration cache
├── checkpoints
├── specific_checkpoint_name
│ ├── pretrained_model # Hugging Face pretrained model directory
│ │ ├── ...
│ └── training_state.pth # optimizer, scheduler, and random states + training step
│ ├── another_specific_checkpoint_name
│ │ ├── ...
│ ├── ...
│ └── last # a softlink to the last logged checkpoint
"""
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
job_name: The WandB job name.
"""
self._cfg = cfg
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
self.checkpoints_dir = self.get_checkpoints_dir(log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir)
# Set up WandB.
def __init__(self, log_dir, job_name, cfg):
self._log_dir = Path(log_dir)
self._log_dir.mkdir(parents=True, exist_ok=True)
self._job_name = job_name
self._model_dir = self._log_dir / "checkpoints"
self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.training.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._save_buffer = cfg.training.get("save_buffer", False)
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
self._cfg = cfg
self._eval = []
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
@@ -109,127 +69,65 @@ class Logger:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
wandb.init(
id=wandb_run_id,
project=project,
entity=entity,
name=wandb_job_name,
name=job_name,
notes=cfg.get("wandb", {}).get("notes"),
# group=self._group,
tags=cfg_to_group(cfg, return_list=True),
dir=log_dir,
dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
# TODO(rcadene): add resume option
resume=None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@classmethod
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
return Path(log_dir) / "checkpoints"
def save_model(self, policy: Policy, identifier):
if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True)
save_dir = self._model_dir / str(identifier)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="model",
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
@classmethod
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
return cls.get_checkpoints_dir(log_dir) / "last"
@classmethod
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
"""
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
be saved.
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
Optionally also upload the model to WandB.
"""
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
policy.save_pretrained(save_dir)
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._cfg.wandb.disable_artifact:
def save_buffer(self, buffer, identifier):
self._buffer_dir.mkdir(parents=True, exist_ok=True)
fp = self._buffer_dir / f"{str(identifier)}.pkl"
buffer.save(fp)
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
if self.last_checkpoint_dir.exists():
os.remove(self.last_checkpoint_dir)
def save_training_state(
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer,
scheduler: LRScheduler | None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
def save_checkpont(
self,
train_step: int,
policy: Policy,
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
wandb_artifact_name = (
None
if self._wandb is None
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError(
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
artifact = self._wandb.Artifact(
f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="buffer",
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
artifact.add_file(fp)
self._wandb.log_artifact(artifact)
def finish(self, agent, buffer):
if self._save_model:
self.save_model(agent, identifier="final")
if self._save_buffer:
self.save_buffer(buffer, identifier="buffer")
if self._wandb:
self._wandb.finish()
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
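As a rough illustration of the newer checkpoint bookkeeping, here is a filesystem-only sketch following the directory layout in the class docstring above (the log directory and the helper name are hypothetical; no policy or wandb involved):

```python
import os
from pathlib import Path

log_dir = Path("outputs/train/example_run")  # placeholder log directory
checkpoints_dir = log_dir / "checkpoints"
last_checkpoint_dir = checkpoints_dir / "last"

def save_checkpoint_dirs(identifier: str) -> Path:
    """Create a checkpoint directory and repoint the 'last' symlink at it."""
    checkpoint_dir = checkpoints_dir / str(identifier)
    (checkpoint_dir / "pretrained_model").mkdir(parents=True, exist_ok=True)
    if last_checkpoint_dir.is_symlink() or last_checkpoint_dir.exists():
        os.remove(last_checkpoint_dir)
    os.symlink(checkpoint_dir.absolute(), last_checkpoint_dir)
    return checkpoint_dir

save_checkpoint_dirs("000100")
save_checkpoint_dirs("000200")
print(os.readlink(last_checkpoint_dir))  # .../checkpoints/000200
```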

View File

@@ -25,13 +25,6 @@ class ACTConfig:
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- At least one key starting with "observation.image" is required as an input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
@@ -40,15 +33,15 @@ class ACTConfig:
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesn't include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -129,9 +122,7 @@ class ACTConfig:
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
# As a consequence we also remove the final, unused layer normalization, by default
n_decoder_layers: int = 1
decoder_norm: bool = False
# VAE.
use_vae: bool = True
latent_dim: int = 32
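For concreteness, hypothetical `input_shapes` / `output_shapes` dictionaries matching the docstring above (the camera name and dimensions are illustrative, not taken from any particular dataset):

```python
# Shapes never include the batch or temporal dimension.
input_shapes = {
    "observation.images.top": [3, 480, 640],  # one RGB camera: channels, height, width
    "observation.state": [14],                 # proprioceptive joint state
}
output_shapes = {
    "action": [14],                            # 14-dimensional joint-space action
}
```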

View File

@@ -198,31 +198,27 @@ class ACT(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_input_state = "observation.state" in config.input_shapes
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_input_state:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
config.output_shapes["action"][0], config.dim_model
config.input_shapes["observation.state"][0], config.dim_model
)
self.latent_dim = config.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding for the whole input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_input_state:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
)
# Backbone for image feature extraction.
@@ -242,17 +238,15 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
if self.use_input_state:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
num_input_token_decoder = 2 if self.use_input_state else 1
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
@@ -291,7 +285,7 @@ class ACT(nn.Module):
"action" in batch
), "actions must be provided when using the variational objective in training mode."
batch_size = batch["observation.images"].shape[0]
batch_size = batch["observation.state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
@@ -299,43 +293,31 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_input_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
1
) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_input_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
# Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder to get the latent PDF parameters.
cls_joint_is_pad = torch.full((batch_size, 2), False).to(
batch["observation.state"].device
) # False: not a padding
key_padding_mask = torch.cat([cls_joint_is_pad, batch["action_is_pad"]], axis=1) # (bs, seq+1)
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2),
pos_embed=pos_embed.permute(1, 0, 2),
key_padding_mask=key_padding_mask,
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.config.latent_dim]
mu = latent_pdf_params[:, : self.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
@@ -344,10 +326,8 @@ class ACT(nn.Module):
all_cam_features = []
all_cam_pos_embeds = []
images = batch["observation.images"]
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
@@ -357,15 +337,13 @@ class ACT(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
if self.use_input_state:
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
encoder_in = torch.cat(
[
torch.stack(encoder_in_feats, axis=0),
torch.stack([latent_embed, robot_state_embed], axis=0),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
]
)
@@ -379,7 +357,6 @@ class ACT(nn.Module):
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype,
@@ -408,11 +385,9 @@ class ACTEncoder(nn.Module):
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
x = layer(x, pos_embed=pos_embed)
x = self.norm(x)
return x
@@ -435,14 +410,12 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
q = k = x if pos_embed is None else x + pos_embed
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)[
0
] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@@ -462,10 +435,7 @@ class ACTDecoder(nn.Module):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
if config.decoder_norm:
self.norm = nn.LayerNorm(config.dim_model)
else:
self.norm = nn.Identity()
self.norm = nn.LayerNorm(config.dim_model)
def forward(
self,
@@ -478,7 +448,8 @@ class ACTDecoder(nn.Module):
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
x = self.norm(x)
if self.norm is not None:
x = self.norm(x)
return x
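For readers curious about the fixed positional embedding used above, here is a standalone sketch of the textbook sinusoidal table that `create_sinusoidal_pos_embedding` provides; this follows the standard Transformer formula and is not necessarily the exact implementation in the file. The token count `1 + 1 + chunk_size` matches the [cls, robot_state, *action_sequence] layout of the VAE encoder input.

```python
import math
import torch

def sinusoidal_pos_embedding(num_positions: int, dim: int) -> torch.Tensor:
    """Return a (num_positions, dim) table: sine on even dims, cosine on odd dims."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
    )
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table

# Hypothetical sizes: [cls, robot_state, *action_chunk] -> 1 + 1 + chunk_size tokens.
chunk_size, dim_model = 100, 512
pos_enc = sinusoidal_pos_embedding(1 + 1 + chunk_size, dim_model).unsqueeze(0)  # add batch dim
print(pos_enc.shape)  # torch.Size([1, 102, 512])
```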

View File

@@ -26,26 +26,21 @@ class DiffusionConfig:
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- A key starting with "observation.image is required as an input.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesn't include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -160,7 +155,7 @@ class DiffusionConfig:
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if self.crop_shape is not None and (
if (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):

View File

@@ -239,8 +239,10 @@ class DiffusionModel(nn.Module):
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond)
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.config.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.config.n_action_steps
@@ -302,11 +304,7 @@ class DiffusionModel(nn.Module):
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if self.config.do_mask_loss_for_padding:
if "action_is_pad" not in batch:
raise ValueError(
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
)
if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
@@ -425,15 +423,11 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
# use the height and width from `config.crop_shape`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
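The "dry run" above is simply a forward pass of a zero-filled dummy image to discover the backbone's output feature-map shape before building the pooling layers. A self-contained sketch of the same trick with a plain torchvision ResNet-18 trunk (the config values are placeholders, and the trunk construction differs from the encoder's actual backbone wrapper):

```python
import torch
import torchvision

# Placeholder config values.
input_shape = [3, 96, 96]   # channels, height, width from config.input_shapes
crop_shape = (84, 84)       # config.crop_shape; may be None

# ResNet-18 trunk without the average-pool and fully-connected head.
backbone = torch.nn.Sequential(*list(torchvision.models.resnet18(weights=None).children())[:-2])

dummy_h_w = crop_shape if crop_shape is not None else input_shape[1:]
dummy_input = torch.zeros(size=(1, input_shape[0], *dummy_h_w))
with torch.inference_mode():
    feature_map_shape = tuple(backbone(dummy_input).shape[1:])
print(feature_map_shape)  # (512, 3, 3) for an 84x84 crop
```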

View File

@@ -147,7 +147,7 @@ class Normalize(nn.Module):
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
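The only difference between the two variants above is the `1e-8` term, which guards against division by zero when a feature is constant (min == max); a short illustration of the full min-max rescale to [-1, 1]:

```python
import torch

x = torch.tensor([0.0, 0.5, 1.0])
min_, max_ = x.min(), x.max()

x01 = (x - min_) / (max_ - min_ + 1e-8)  # normalize to [0, 1]; safe even when min == max
x11 = x01 * 2 - 1                        # rescale to [-1, 1]
print(x11)  # ~tensor([-1.0000,  0.0000,  1.0000]), up to the 1e-8 epsilon
```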

View File

@@ -31,15 +31,6 @@ class TDMPCConfig:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
action repeats in Q-learning or ask your favorite chatbot)
horizon: Horizon for model predictive control.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a

View File

@@ -1,325 +0,0 @@
import argparse
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import time
import traceback
import cv2
import numpy as np
import pyrealsense2 as rs
from lerobot.common.robot_devices.cameras.opencv import find_camera_indices
from lerobot.common.robot_devices.cameras.utils import save_color_image, save_depth_image
SERIAL_NUMBER_INDEX = 1
def find_camera_indices(raise_when_empty=True):
camera_ids = []
for device in rs.context().query_devices():
serial_number = int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX)))
camera_ids.append(serial_number)
if raise_when_empty and len(camera_ids) == 0:
raise OSError("Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware.")
return camera_ids
def benchmark_cameras(cameras, out_dir=None, save_images=False):
if save_images:
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
while True:
now = time.time()
for camera in cameras:
if camera.use_depth:
color_image, depth_image = camera.capture_image("bgr" if save_images else "rgb")
else:
color_image = camera.capture_image("bgr" if save_images else "rgb")
if save_images:
image_path = out_dir / f"camera_{camera.camera_index:02}.png"
print(f"Write to {image_path}")
save_color_image(color_image, image_path, write_shape=True)
if camera.use_depth:
# Apply colormap on depth image (image must be converted to 8-bit per pixel first)
depth_image_path = out_dir / f"camera_{camera.camera_index:02}_depth.png"
print(f"Write to {depth_image_path}")
save_depth_image(depth_image_path, depth_image, write_shape=True)
dt_s = (time.time() - now)
dt_ms = dt_s * 1000
freq = 1 / dt_s
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
if save_images:
break
if cv2.waitKey(1) & 0xFF == ord("q"):
break
# Pre-defined configs that worked
@dataclass
class IntelRealSenseCameraConfig:
"""
Example of tested options for Intel Real Sense D405:
```python
IntelRealSenseCameraConfig(30, 640, 480)
IntelRealSenseCameraConfig(60, 640, 480)
IntelRealSenseCameraConfig(90, 640, 480)
IntelRealSenseCameraConfig(30, 1280, 720)
IntelRealSenseCameraConfig(30, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(60, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(90, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(30, 1280, 720, use_depth=True)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color: str = "rgb"
use_depth: bool = False
force_hardware_reset: bool = True
class IntelRealSenseCamera():
# TODO(rcadene): improve docstring
"""
Using this class requires:
- [installing `librealsense` and its python wrapper `pyrealsense2`](https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md)
- [updating the camera(s) firmware](https://dev.intelrealsense.com/docs/firmware-releases-d400)
Example of getting the `camera_index` for your camera(s):
```bash
rs-fw-update -l
> Connected devices:
> 1) [USB] Intel RealSense D405 s/n 128422270109, update serial number: 133323070634, firmware version: 5.16.0.1
> 2) [USB] Intel RealSense D405 s/n 128422271609, update serial number: 130523070758, firmware version: 5.16.0.1
> 3) [USB] Intel RealSense D405 s/n 128422271614, update serial number: 133323070576, firmware version: 5.16.0.1
> 4) [USB] Intel RealSense D405 s/n 128422271393, update serial number: 133323070271, firmware version: 5.16.0.1
```
Example of usage:
```python
camera = IntelRealSenseCamera(128422270109) # serial number (s/n)
color_image = camera.capture_image()
```
Example of capturing additional depth image:
```python
config = IntelRealSenseCameraConfig(use_depth=True)
camera = IntelRealSenseCamera(128422270109, config)
color_image, depth_image = camera.capture_image()
```
"""
AVAILABLE_CAMERA_INDICES = find_camera_indices()
def __init__(self,
camera_index: int | None = None,
config: IntelRealSenseCameraConfig | None = None,
**kwargs,
):
if config is None:
config = IntelRealSenseCameraConfig()
# Overwrite config arguments using kwargs
config = replace(config, **kwargs)
self.camera_index = camera_index
self.fps = config.fps
self.width = config.width
self.height = config.height
self.color = config.color
self.use_depth = config.use_depth
self.force_hardware_reset = config.force_hardware_reset
# TODO(rcadene): move these two check in config dataclass
if self.color not in ["rgb", "bgr"]:
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
if (self.fps or self.width or self.height) and not (self.fps and self.width and self.height):
raise ValueError(f"Expected all fps, width and height to be set, when one of them is set, but {self.fps=}, {self.width=}, {self.height=}.")
if self.camera_index is None:
raise ValueError(f"`camera_index` is expected to be a serial number of one of these available cameras ({IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES}), but {camera_index} is provided instead.")
self.camera = None
self.is_connected = False
self.t = Thread(target=self.capture_image_loop, args=())
self.t.daemon = True
self._color_image = None
def connect(self):
if self.is_connected:
raise ValueError(f"Camera {self.camera_index} is already connected.")
config = rs.config()
config.enable_device(str(self.camera_index))
if self.fps and self.width and self.height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
else:
config.enable_stream(rs.stream.color)
if self.use_depth:
if self.fps and self.width and self.height:
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
else:
config.enable_stream(rs.stream.depth)
self.camera = rs.pipeline()
try:
self.camera.start(config)
except RuntimeError:
# Verify that the provided `camera_index` is valid before printing the traceback
if self.camera_index not in IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES:
raise ValueError(f"`camera_index` is expected to be a serial number of one of these available cameras {IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES}, but {self.camera_index} is provided instead.")
traceback.print_exc()
self.is_connected = True
self.t.start()
def capture_image(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
frame = self.camera.wait_for_frames()
color_frame = frame.get_color_frame()
if not color_frame:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
color_image = np.asanyarray(color_frame.get_data())
if temporary_color is None:
requested_color = self.color
else:
requested_color = temporary_color
if requested_color not in ["rgb", "bgr"]:
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
# OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images.
# However, deep learning frameworks such as LeRobot use RGB format as default to train neural networks,
# so we convert the image color from BGR to RGB.
# if requested_color == "rgb":
# color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
if self.use_depth:
depth_frame = frame.get_depth_frame()
if not depth_frame:
raise OSError(f"Can't capture depth image from camera {self.camera_index}.")
depth_image = np.asanyarray(depth_frame.get_data())
return color_image, depth_image
else:
return color_image
def capture_image_loop(self):
while True:
self._color_image = self.capture_image()
def read(self):
while self._color_image is None:
time.sleep(0.1)
return self._color_image
def disconnect(self):
if getattr(self, "camera", None):
try:
self.camera.stop()
except RuntimeError as e:
if "stop() cannot be called before start()" in str(e):
# skip this runtime error
return
traceback.print_exc()
def __del__(self):
self.disconnect()
def save_images_config(config, out_dir: Path):
camera_ids = IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES
cameras = []
print(f"Available camera indices: {camera_ids}")
for camera_idx in camera_ids:
camera = IntelRealSenseCamera(camera_idx, config)
cameras.append(camera)
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}_depth_{config.use_depth}"
benchmark_cameras(cameras, out_dir, save_images=True)
def benchmark_config(config, camera_ids: list[int]):
cameras = [IntelRealSenseCamera(idx, config) for idx in camera_ids]
benchmark_cameras(cameras)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, choices=["save_images", 'benchmark'], default="save_images")
parser.add_argument("--camera-ids", type=int, nargs="*", default=[128422271609, 128422271614, 128422271393])
parser.add_argument("--fps", type=int, default=30)
parser.add_argument("--width", type=str, default=640)
parser.add_argument("--height", type=str, default=480)
parser.add_argument("--use-depth", type=int, default=0)
parser.add_argument("--out-dir", type=Path, default="outputs/benchmark_cameras/intelrealsense/2024_06_22_1738")
args = parser.parse_args()
config = IntelRealSenseCameraConfig(args.fps, args.width, args.height, use_depth=bool(args.use_depth))
# config = IntelRealSenseCameraConfig()
# config = IntelRealSenseCameraConfig(60, 640, 480)
# config = IntelRealSenseCameraConfig(90, 640, 480)
# config = IntelRealSenseCameraConfig(30, 1280, 720)
if args.mode == "save_images":
save_images_config(config, args.out_dir)
elif args.mode == "benchmark":
benchmark_config(config, args.camera_ids)
else:
raise ValueError(args.mode)
# if __name__ == "__main__":
# # Works well!
# # use_depth = False
# # fps = 90
# # width = 640
# # height = 480
# # # Works well!
# # use_depth = True
# # fps = 90
# # width = 640
# # height = 480
# # # Doesn't work well, latency varies too much
# # use_depth = True
# # fps = 30
# # width = 1280
# # height = 720
# # Works well
# use_depth = False
# fps = 30
# width = 1280
# height = 720
# config = IntelRealSenseCameraConfig()
# # config = IntelRealSenseCameraConfig(fps, width, height, use_depth=use_depth)
# cameras = [
# # IntelRealSenseCamera(0, config),
# # IntelRealSenseCamera(128422270109, config),
# IntelRealSenseCamera(128422271609, config),
# IntelRealSenseCamera(128422271614, config),
# IntelRealSenseCamera(128422271393, config),
# ]
# out_dir = "outputs/benchmark_cameras/intelrealsense/2024_06_22_1729"
# out_dir += f"{config.width}x{config.height}_{config.fps}_depth_{config.use_depth}"
# benchmark_cameras(cameras, out_dir, save_images=False)

View File

@@ -1,249 +0,0 @@
import argparse
from dataclasses import dataclass, replace
from pathlib import Path
from threading import Thread
import time
import cv2
import numpy as np
from lerobot.common.robot_devices.cameras.utils import save_color_image
def find_camera_indices(raise_when_empty=False, max_index_search_range=60):
camera_ids = []
for camera_idx in range(max_index_search_range):
camera = cv2.VideoCapture(camera_idx)
is_open = camera.isOpened()
camera.release()
if is_open:
print(f"Camera found at index {camera_idx}")
camera_ids.append(camera_idx)
if raise_when_empty and len(camera_ids) == 0:
raise OSError("Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2.")
return camera_ids
def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frames=4):
if out_dir:
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for _ in range(num_warmup_frames):
for camera in cameras:
try:
camera.capture_image()
time.sleep(0.01)
except OSError as e:
print(e)
while True:
now = time.time()
for camera in cameras:
color_image = camera.capture_image("bgr" if save_images else "rgb")
if save_images:
image_path = out_dir / f"camera_{camera.camera_index:02}.png"
print(f"Write to {image_path}")
save_color_image(color_image, image_path, write_shape=True)
dt_s = (time.time() - now)
dt_ms = dt_s * 1000
freq = 1 / dt_s
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
if save_images:
break
if cv2.waitKey(1) & 0xFF == ord("q"):
break
@dataclass
class OpenCVCameraConfig:
"""
Example of tested options for Intel Real Sense D405:
```python
OpenCVCameraConfig(30, 640, 480)
OpenCVCameraConfig(60, 640, 480)
OpenCVCameraConfig(90, 640, 480)
OpenCVCameraConfig(30, 1280, 720)
```
"""
fps: int | None = None
width: int | None = None
height: int | None = None
color: str = "rgb"
class OpenCVCamera():
# TODO(rcadene): improve docstring
"""
https://docs.opencv.org/4.x/d0/da7/videoio_overview.html
https://docs.opencv.org/4.x/d4/d15/group__videoio__flags__base.html#ga023786be1ee68a9105bf2e48c700294d
Example of usage:
```python
camera = OpenCVCamera(2)
color_image = camera.capture_image()
```
"""
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
if config is None:
config = OpenCVCameraConfig()
# Overwrite config arguments using kwargs
config = replace(config, **kwargs)
self.camera_index = camera_index
self.fps = config.fps
self.width = config.width
self.height = config.height
self.color = config.color
if self.color not in ["rgb", "bgr"]:
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
if self.camera_index is None:
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead.")
self.camera = None
self.is_connected = False
self.t = Thread(target=self.capture_image_loop, args=())
self.t.daemon = True
self._color_image = None
def connect(self):
if self.is_connected:
raise ValueError(f"Camera {self.camera_index} is already connected.")
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(self.camera_index)
is_camera_open = tmp_camera.isOpened()
# Release camera to make it accessible for `find_camera_indices`
del tmp_camera
# If the camera doesn't work, display the camera indices corresponding to
# valid cameras.
if not is_camera_open:
# Verify that the provided `camera_index` is valid before printing the traceback
if self.camera_index not in find_camera_indices():
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead.")
raise OSError(f"Can't access camera {self.camera_index}.")
# Secondly, create the camera that will be used downstream.
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
# needs to be re-created.
self.camera = cv2.VideoCapture(self.camera_index)
if self.fps:
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
if self.width:
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
if self.height:
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
if self.fps and self.fps != actual_fps:
raise OSError(f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}.")
if self.width and self.width != actual_width:
raise OSError(f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}.")
if self.height and self.height != actual_height:
raise OSError(f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}.")
self.is_connected = True
self.t.start()
def capture_image(self, temporary_color: str | None = None) -> np.ndarray:
if not self.is_connected:
self.connect()
ret, color_image = self.camera.read()
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
if temporary_color is None:
requested_color = self.color
else:
requested_color = temporary_color
if requested_color not in ["rgb", "bgr"]:
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
# OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images.
# However, deep learning frameworks such as LeRobot use RGB format as default to train neural networks,
# so we convert the image color from BGR to RGB.
if requested_color == "rgb":
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
return color_image
def capture_image_loop(self):
while True:
self._color_image = self.capture_image()
def read(self):
while self._color_image is None:
time.sleep(0.1)
return self._color_image
def disconnect(self):
if getattr(self, "camera", None):
self.camera.release()
def __del__(self):
self.disconnect()
def save_images_config(config: OpenCVCameraConfig, out_dir: Path):
cameras = []
print(f"Available camera indices: {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}")
for camera_idx in OpenCVCamera.AVAILABLE_CAMERAS_INDICES:
camera = OpenCVCamera(camera_idx, config)
cameras.append(camera)
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}"
benchmark_cameras(cameras, out_dir, save_images=True)
def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
cameras = [OpenCVCamera(idx, config) for idx in camera_ids]
benchmark_cameras(cameras)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, choices=["save_images", 'benchmark'], default="save_images")
parser.add_argument("--camera-ids", type=int, nargs="*", default=[16, 4, 22, 10])
parser.add_argument("--fps", type=int, default=30)
parser.add_argument("--width", type=str, default=640)
parser.add_argument("--height", type=str, default=480)
parser.add_argument("--out-dir", type=Path, default="outputs/benchmark_cameras/opencv/2024_06_22_1727")
args = parser.parse_args()
config = OpenCVCameraConfig(args.fps, args.width, args.height)
# config = OpenCVCameraConfig()
# config = OpenCVCameraConfig(60, 640, 480)
# config = OpenCVCameraConfig(90, 640, 480)
# config = OpenCVCameraConfig(30, 1280, 720)
if args.mode == "save_images":
save_images_config(config, args.out_dir)
elif args.mode == "benchmark":
benchmark_config(config, args.camera_ids)
else:
raise ValueError(args.mode)

View File

@@ -1,48 +0,0 @@
from pathlib import Path
import time
import cv2
from typing import Protocol
import numpy as np
def write_shape_on_image_inplace(image):
height, width = image.shape[:2]
text = f'Width: {width} Height: {height}'
# Define the font, scale, color, and thickness
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
color = (255, 0, 0) # Blue in BGR
thickness = 2
position = (10, height - 10) # 10 pixels from the bottom-left corner
cv2.putText(image, text, position, font, font_scale, color, thickness)
def save_color_image(image, path, write_shape=False):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
if write_shape:
write_shape_on_image_inplace(image)
cv2.imwrite(str(path), image)
def save_depth_image(depth, path, write_shape=False):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
# Apply colormap on depth image (image must be converted to 8-bit per pixel first)
depth_image = cv2.applyColorMap(cv2.convertScaleAbs(depth, alpha=0.03), cv2.COLORMAP_JET)
if write_shape:
write_shape_on_image_inplace(depth_image)
cv2.imwrite(str(path), depth_image)
# Defines a camera type
class Camera(Protocol):
def connect(self): ...
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
def disconnect(self): ...

View File

@@ -1,405 +0,0 @@
from copy import deepcopy
import enum
from typing import Union
import numpy as np
from dynamixel_sdk import PacketHandler, PortHandler, COMM_SUCCESS, GroupSyncRead, GroupSyncWrite
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD
PROTOCOL_VERSION = 2.0
BAUD_RATE = 1_000_000
TIMEOUT_MS = 1000
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350
# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270
# data_name: (address, size_byte)
X_SERIES_CONTROL_TABLE = {
"Model_Number": (0, 2),
"Model_Information": (2, 4),
"Firmware_Version": (6, 1),
"ID": (7, 1),
"Baud_Rate": (8, 1),
"Return_Delay_Time": (9, 1),
"Drive_Mode": (10, 1),
"Operating_Mode": (11, 1),
"Secondary_ID": (12, 1),
"Protocol_Type": (13, 1),
"Homing_Offset": (20, 4),
"Moving_Threshold": (24, 4),
"Temperature_Limit": (31, 1),
"Max_Voltage_Limit": (32, 2),
"Min_Voltage_Limit": (34, 2),
"PWM_Limit": (36, 2),
"Current_Limit": (38, 2),
"Acceleration_Limit": (40, 4),
"Velocity_Limit": (44, 4),
"Max_Position_Limit": (48, 4),
"Min_Position_Limit": (52, 4),
"Shutdown": (63, 1),
"Torque_Enable": (64, 1),
"LED": (65, 1),
"Status_Return_Level": (68, 1),
"Registered_Instruction": (69, 1),
"Hardware_Error_Status": (70, 1),
"Velocity_I_Gain": (76, 2),
"Velocity_P_Gain": (78, 2),
"Position_D_Gain": (80, 2),
"Position_I_Gain": (82, 2),
"Position_P_Gain": (84, 2),
"Feedforward_2nd_Gain": (88, 2),
"Feedforward_1st_Gain": (90, 2),
"Bus_Watchdog": (98, 1),
"Goal_PWM": (100, 2),
"Goal_Current": (102, 2),
"Goal_Velocity": (104, 4),
"Profile_Acceleration": (108, 4),
"Profile_Velocity": (112, 4),
"Goal_Position": (116, 4),
"Realtime_Tick": (120, 2),
"Moving": (122, 1),
"Moving_Status": (123, 1),
"Present_PWM": (124, 2),
"Present_Current": (126, 2),
"Present_Velocity": (128, 4),
"Present_Position": (132, 4),
"Velocity_Trajectory": (136, 4),
"Position_Trajectory": (140, 4),
"Present_Input_Voltage": (144, 2),
"Present_Temperature": (146, 1)
}
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
#CONVERT_POSITION_TO_ANGLE_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_POSITION_TO_ANGLE_REQUIRED = []
MODEL_CONTROL_TABLE = {
"x_series": X_SERIES_CONTROL_TABLE,
"xl330-m077": X_SERIES_CONTROL_TABLE,
"xl330-m288": X_SERIES_CONTROL_TABLE,
"xl430-w250": X_SERIES_CONTROL_TABLE,
"xm430-w350": X_SERIES_CONTROL_TABLE,
"xm540-w270": X_SERIES_CONTROL_TABLE,
}
def uint32_to_int32(values: np.ndarray):
"""
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
"""
for i in range(len(values)):
if values[i] is not None and values[i] > 2147483647:
values[i] = values[i] - 4294967296
return values
def int32_to_uint32(values: np.ndarray):
"""
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
"""
for i in range(len(values)):
if values[i] is not None and values[i] < 0:
values[i] = values[i] + 4294967296
return values
def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
"""
Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
"""
return (position / 2048) * 3.14
def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
"""
Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
"""
return ((angle / 3.14) * 2048).astype(np.int64)
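A quick numeric check of the two conversions above, using the same 3.14 approximation of pi as the code (the functions are restated so the snippet runs on its own):

```python
import numpy as np

def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
    return (position / 2048) * 3.14

def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
    return ((angle / 3.14) * 2048).astype(np.int64)

positions = np.array([-2048, 0, 1024, 2048])
angles = motor_position_to_angle(positions)
print(angles)                           # [-3.14  0.    1.57  3.14]
print(motor_angle_to_position(angles))  # [-2048     0  1024  2048]
```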
# def pwm2vel(pwm: np.ndarray) -> np.ndarray:
# """
# :param pwm: numpy array of pwm/s joint velocities
# :return: numpy array of rad/s joint velocities
# """
# return pwm * 3.14 / 2048
# def vel2pwm(vel: np.ndarray) -> np.ndarray:
# """
# :param vel: numpy array of rad/s joint velocities
# :return: numpy array of pwm/s joint velocities
# """
# return (vel * 2048 / 3.14).astype(np.int64)
def get_group_sync_key(data_name, motor_names):
group_key = f"{data_name}_" + "_".join([name for name in motor_names])
return group_key
class TorqueMode(enum.Enum):
ENABLED = 1
DISABLED = 0
class OperatingMode(enum.Enum):
VELOCITY = 1
POSITION = 3
EXTENDED_POSITION = 4
CURRENT_CONTROLLED_POSITION = 5
PWM = 16
UNKNOWN = -1
class DriveMode(enum.Enum):
NON_INVERTED = 0
INVERTED = 1
class DynamixelMotorsBus:
def __init__(self, port: str, motors: dict[str, tuple[int, str]],
extra_model_control_table: dict[str, list[tuple]] | None = None):
self.port = port
self.motors = motors
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
if not self.port_handler.openPort():
raise OSError(f"Failed to open port {self.port}")
self.port_handler.setBaudRate(BAUD_RATE)
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
self.group_readers = {}
self.group_writers = {}
self.calibration = None
@property
def motor_names(self) -> list[int]:
return list(self.motors.keys())
def set_calibration(self, calibration: dict[str, tuple[int, bool]]):
self.calibration = calibration
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration:
return values
if motor_names is None:
motor_names = self.motor_names
for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name]
if values[i] is not None:
if drive_mode:
values[i] *= -1
values[i] += homing_offset
return values
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
if not self.calibration:
return values
if motor_names is None:
motor_names = self.motor_names
for i, name in enumerate(motor_names):
homing_offset, drive_mode = self.calibration[name]
if values[i] is not None:
values[i] -= homing_offset
if drive_mode:
values[i] *= -1
return values
def read(self, data_name, motor_names: list[str] | None = None):
if motor_names is None:
motor_names = self.motor_names
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
# TODO(rcadene): assert all motors follow same address
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
if data_name not in self.group_readers:
# create new group reader
self.group_readers[group_key] = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
for idx in motor_ids:
self.group_readers[group_key].addParam(idx)
comm = self.group_readers[group_key].txRxPacket()
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = self.group_readers[group_key].getData(idx, addr, bytes)
values.append(value)
values = np.array(values)
# TODO(rcadene): explain why
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
values = uint32_to_int32(values)
if data_name in CALIBRATION_REQUIRED:
values = self.apply_calibration(values, motor_names)
if data_name in CONVERT_POSITION_TO_ANGLE_REQUIRED:
values = motor_position_to_angle(values)
return values
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str):
motor_names = [motor_names]
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
if isinstance(values, (int, float, np.integer)):
values = [int(values)] * len(motor_ids)
values = np.array(values)
if data_name in CONVERT_POSITION_TO_ANGLE_REQUIRED:
values = motor_angle_to_position(values)
if data_name in CALIBRATION_REQUIRED:
values = self.revert_calibration(values, motor_names)
# TODO(rcadene): why dont we do it?
# if data_name in CONVERT_INT32_TO_UINT32_REQUIRED:
# values = int32_to_uint32(values)
values = values.tolist()
# TODO(rcadene): assert all motors follow same address
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
init_group = data_name not in self.group_readers
if init_group:
self.group_writers[group_key] = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
for idx, value in zip(motor_ids, values):
if bytes == 1:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
]
elif bytes == 2:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
]
elif bytes == 4:
data = [
DXL_LOBYTE(DXL_LOWORD(value)),
DXL_HIBYTE(DXL_LOWORD(value)),
DXL_LOBYTE(DXL_HIWORD(value)),
DXL_HIBYTE(DXL_HIWORD(value)),
]
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead.")
if init_group:
self.group_writers[group_key].addParam(idx, data)
else:
self.group_writers[group_key].changeParam(idx, data)
comm = self.group_writers[group_key].txPacket()
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
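# A minimal usage sketch of the two methods above (hypothetical port and motor mapping, mirroring the
# KochRobotConfig further below; not part of this file):
# bus = DynamixelMotorsBus(port="/dev/ttyUSB0", motors={"gripper": (6, "xl330-m288")})
# pos = bus.read("Present_Position")          # np.ndarray with one value per requested motor
# bus.write("Goal_Position", pos, "gripper")  # accepts a scalar, an array, or per-motor values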
# def read(self, data_name, motor_name: str):
# motor_idx, model = self.motors[motor_name]
# addr, bytes = self.model_ctrl_table[model][data_name]
# args = (self.port_handler, motor_idx, addr)
# if bytes == 1:
# value, comm, err = self.packet_handler.read1ByteTxRx(*args)
# elif bytes == 2:
# value, comm, err = self.packet_handler.read2ByteTxRx(*args)
# elif bytes == 4:
# value, comm, err = self.packet_handler.read4ByteTxRx(*args)
# else:
# raise NotImplementedError(
# f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
# f"{bytes} is provided instead.")
# if comm != COMM_SUCCESS:
# raise ConnectionError(
# f"Read failed due to communication error on port {self.port} for motor {motor_idx}: "
# f"{self.packet_handler.getTxRxResult(comm)}"
# )
# elif err != 0:
# raise ConnectionError(
# f"Read failed due to error {err} on port {self.port} for motor {motor_idx}: "
# f"{self.packet_handler.getTxRxResult(err)}"
# )
# if data_name in CALIBRATION_REQUIRED:
# value = self.apply_calibration([value], [motor_name])[0]
# return value
# def write(self, data_name, value, motor_name: str):
# if data_name in CALIBRATION_REQUIRED:
# value = self.revert_calibration([value], [motor_name])[0]
# motor_idx, model = self.motors[motor_name]
# addr, bytes = self.model_ctrl_table[model][data_name]
# args = (self.port_handler, motor_idx, addr, value)
# if bytes == 1:
# comm, err = self.packet_handler.write1ByteTxRx(*args)
# elif bytes == 2:
# comm, err = self.packet_handler.write2ByteTxRx(*args)
# elif bytes == 4:
# comm, err = self.packet_handler.write4ByteTxRx(*args)
# else:
# raise NotImplementedError(
# f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but {bytes} "
# f"is provided instead.")
# if comm != COMM_SUCCESS:
# raise ConnectionError(
# f"Write failed due to communication error on port {self.port} for motor {motor_idx}: "
# f"{self.packet_handler.getTxRxResult(comm)}"
# )
# elif err != 0:
# raise ConnectionError(
# f"Write failed due to error {err} on port {self.port} for motor {motor_idx}: "
# f"{self.packet_handler.getTxRxResult(err)}"
# )

View File

@@ -1,713 +0,0 @@
from copy import deepcopy
import enum
import numpy as np
from scservo_sdk import PacketHandler, PortHandler, COMM_SUCCESS, GroupSyncRead, GroupSyncWrite
from scservo_sdk import SCS_HIBYTE, SCS_HIWORD, SCS_LOBYTE, SCS_LOWORD
PROTOCOL_VERSION = 0
BAUD_RATE = 1_000_000
TIMEOUT_MS = 1000
def u32_to_i32(value: int | np.array) -> int | np.array:
"""
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
"""
if isinstance(value, int):
if value > 2147483647:
value = value - 4294967296
else:
for i in range(len(value)):
if value[i] is not None and value[i] > 2147483647:
value[i] = value[i] - 4294967296
return value
def i32_to_u32(value: int | np.array) -> int | np.array:
"""
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
"""
if isinstance(value, int):
if value < 0:
value = value + 4294967296
else:
for i in range(len(value)):
if value[i] is not None and value[i] < 0:
value[i] = value[i] + 4294967296
return value
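# Worked example of the scalar conversions above (two's complement on 32 bits):
# u32_to_i32(4294967295) -> -1          (4294967295 - 4294967296)
# i32_to_u32(-1)         -> 4294967295  (-1 + 4294967296)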
def retrieve_ids_and_command(values: np.array, ids: np.array) -> (list[int], np.array):
"""
Convert the values to a chain command. Skip the None values and return the ids and values.
"""
non_none_values = np.array([value for value in values if value is not None])
non_none_values_ids = [ids[i] for i, value in enumerate(values) if value is not None]
return non_none_values_ids, non_none_values
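# Example with hypothetical values: None entries are dropped and only the matching ids are kept:
# retrieve_ids_and_command(np.array([100, None, 300], dtype=object), [1, 2, 3])
# -> ([1, 3], array([100, 300]))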
class TorqueMode(enum.Enum):
ENABLED = 1
DISABLED = 0
class OperatingMode(enum.Enum):
pass
class DriveMode(enum.Enum):
pass
SCS_SERIES_CONTROL_TABLE = [
("Model", 3, 2),
("ID", 5, 1),
("Baud_Rate", 6, 1),
("Return_Delay", 7, 1),
("Response_Status_Level", 8, 1),
("Min_Angle_Limit", 9, 2),
("Max_Angle_Limit", 11, 2),
("Max_Temperature_Limit", 13, 1),
("Max_Voltage_Limit", 14, 1),
("Min_Voltage_Limit", 15, 1),
("Max_Torque_Limit", 16, 2),
("Phase", 18, 1),
("Unloading_Condition", 19, 1),
("LED_Alarm_Condition", 20, 1),
("P_Coefficient", 21, 1),
("D_Coefficient", 22, 1),
("I_Coefficient", 23, 1),
("Minimum_Startup_Force", 24, 2),
("CW_Dead_Zone", 26, 1),
("CCW_Dead_Zone", 27, 1),
("Protection_Current", 28, 2),
("Angular_Resolution", 30, 1),
("Offset", 31, 2),
("Mode", 33, 1),
("Protective_Torque", 34, 1),
("Protection_Time", 35, 1),
("Overload_Torque", 36, 1),
("Speed_closed_loop_P_proportional_coefficient", 37, 1),
("Over_Current_Protection_Time", 38, 1),
("Velocity_closed_loop_I_integral_coefficient", 39, 1),
("Torque_Enable", 40, 1),
("Acceleration", 41, 1),
("Goal_Position", 42, 2),
("Goal_Time", 44, 2),
("Goal_Speed", 46, 2),
("Lock", 55, 1),
("Present_Position", 56, 2),
("Present_Speed", 58, 2),
("Present_Load", 60, 2),
("Present_Voltage", 62, 1),
("Present_Temperature", 63, 1),
("Status", 65, 1),
("Moving", 66, 1),
("Present_Current", 69, 2)
]
MODEL_CONTROL_TABLE = {
"scs_series": SCS_SERIES_CONTROL_TABLE,
"sts3215": SCS_SERIES_CONTROL_TABLE,
}
class FeetechBus:
def __init__(self, port: str, motor_models: dict[int, str],
extra_model_control_table: dict[str, list[tuple]] | None = None):
self.port = port
self.motor_models = motor_models
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
if extra_model_control_table:
self.model_ctrl_table.update(extra_model_control_table)
# Find read/write addresses and number of bytes for each motor
self.motor_ctrl = {}
for idx, model in self.motor_models.items():
for data_name, addr, bytes in self.model_ctrl_table[model]:
if idx not in self.motor_ctrl:
self.motor_ctrl[idx] = {}
self.motor_ctrl[idx][data_name] = {
"addr": addr,
"bytes": bytes,
}
self.port_handler = PortHandler(self.port)
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
if not self.port_handler.openPort():
raise OSError(f"Failed to open port {self.port}")
self.port_handler.setBaudRate(BAUD_RATE)
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
self.group_readers = {}
self.group_writers = {}
@property
def motor_ids(self) -> list[int]:
return list(self.motor_models.keys())
def close(self):
self.port_handler.closePort()
def write(self, data_name, value, motor_idx: int):
addr = self.motor_ctrl[motor_idx][data_name]["addr"]
bytes = self.motor_ctrl[motor_idx][data_name]["bytes"]
args = (self.port_handler, motor_idx, addr, value)
if bytes == 1:
comm, err = self.packet_handler.write1ByteTxRx(*args)
elif bytes == 2:
comm, err = self.packet_handler.write2ByteTxRx(*args)
elif bytes == 4:
comm, err = self.packet_handler.write4ByteTxRx(*args)
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but {bytes} "
f"is provided instead.")
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port} for motor {motor_idx}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
elif err != 0:
raise ConnectionError(
f"Write failed due to error {err} on port {self.port} for motor {motor_idx}: "
f"{self.packet_handler.getTxRxResult(err)}"
)
def read(self, data_name, motor_idx: int):
addr = self.motor_ctrl[motor_idx][data_name]["addr"]
bytes = self.motor_ctrl[motor_idx][data_name]["bytes"]
args = (self.port_handler, motor_idx, addr)
if bytes == 1:
value, comm, err = self.packet_handler.read1ByteTxRx(*args)
elif bytes == 2:
value, comm, err = self.packet_handler.read2ByteTxRx(*args)
elif bytes == 4:
value, comm, err = self.packet_handler.read4ByteTxRx(*args)
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead.")
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port} for motor {motor_idx}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
elif err != 0:
raise ConnectionError(
f"Read failed due to error {err} on port {self.port} for motor {motor_idx}: "
f"{self.packet_handler.getTxRxResult(err)}"
)
return value
def sync_read(self, data_name, motor_ids: list[int] | None = None):
if motor_ids is None:
motor_ids = self.motor_ids
group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids])
first_motor_idx = list(self.motor_ctrl.keys())[0]
addr = self.motor_ctrl[first_motor_idx][data_name]["addr"]
bytes = self.motor_ctrl[first_motor_idx][data_name]["bytes"]
if group_key not in self.group_readers:
self.group_readers[group_key] = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
for idx in motor_ids:
self.group_readers[group_key].addParam(idx)
comm = self.group_readers[group_key].txRxPacket()
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = self.group_readers[group_key].getData(idx, addr, bytes)
values.append(value)
return np.array(values)
def sync_write(self, data_name, values: int | list[int], motor_ids: int | list[int] | None = None):
if motor_ids is None:
motor_ids = self.motor_ids
if isinstance(motor_ids, int):
motor_ids = [motor_ids]
if isinstance(values, (int, np.integer)):
values = [int(values)] * len(motor_ids)
if isinstance(values, np.ndarray):
values = values.tolist()
group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids])
first_motor_idx = list(self.motor_ctrl.keys())[0]
addr = self.motor_ctrl[first_motor_idx][data_name]["addr"]
bytes = self.motor_ctrl[first_motor_idx][data_name]["bytes"]
init_group = group_key not in self.group_writers
if init_group:
self.group_writers[group_key] = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
for idx, value in zip(motor_ids, values):
if bytes == 1:
data = [
SCS_LOBYTE(SCS_LOWORD(value)),
]
elif bytes == 2:
data = [
SCS_LOBYTE(SCS_LOWORD(value)),
SCS_HIBYTE(SCS_LOWORD(value)),
]
elif bytes == 4:
data = [
SCS_LOBYTE(SCS_LOWORD(value)),
SCS_HIBYTE(SCS_LOWORD(value)),
SCS_LOBYTE(SCS_HIWORD(value)),
SCS_HIBYTE(SCS_HIWORD(value)),
]
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but {bytes} "
f"is provided instead.")
if init_group:
self.group_writers[group_key].addParam(idx, data)
else:
self.group_writers[group_key].changeParam(idx, data)
comm = self.group_writers[group_key].txPacket()
if comm != COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
def read_model(self, motor_idx: int):
return self.read("Model", motor_idx)
def sync_read_model(self, motor_ids: list[int] | None = None):
return self.sync_read("Model", motor_ids)
def write_id(self, value, motor_idx: int):
self.write("ID", value, motor_idx)
def read_id(self, motor_idx: int):
return self.read("ID", motor_idx)
def sync_read_id(self, motor_ids: list[int] | None = None):
return self.sync_read("ID", motor_ids)
def sync_write_id(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("ID", values, motor_ids)
def write_baud_rate(self, value, motor_idx: int):
self.write("Baud_Rate", value, motor_idx)
def read_baud_rate(self, motor_idx: int):
return self.read("Baud_Rate", motor_idx)
def sync_read_baud_rate(self, motor_ids: list[int] | None = None):
return self.sync_read("Baud_Rate", motor_ids)
def sync_write_baud_rate(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Baud_Rate", values, motor_ids)
def read_return_delay(self, motor_idx: int):
return self.read("Return_Delay", motor_idx)
def sync_read_return_delay(self, motor_ids: list[int] | None = None):
return self.sync_read("Return_Delay", motor_ids)
def read_response_status_level(self, motor_idx: int):
return self.read("Response_Status_Level", motor_idx)
def sync_read_response_status_level(self, motor_ids: list[int] | None = None):
return self.sync_read("Response_Status_Level", motor_ids)
def write_min_angle_limit(self, value, motor_idx: int):
self.write("Min_Angle_Limit", value, motor_idx)
def read_min_angle_limit(self, motor_idx: int):
return self.read("Min_Angle_Limit", motor_idx)
def sync_read_min_angle_limit(self, motor_ids: list[int] | None = None):
return self.sync_read("Min_Angle_Limit", motor_ids)
def sync_write_min_angle_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Min_Angle_Limit", values, motor_ids)
def write_max_angle_limit(self, value, motor_idx: int):
self.write("Max_Angle_Limit", value, motor_idx)
def read_max_angle_limit(self, motor_idx: int):
return self.read("Max_Angle_Limit", motor_idx)
def sync_read_max_angle_limit(self, motor_ids: list[int] | None = None):
return self.sync_read("Max_Angle_Limit", motor_ids)
def sync_write_max_angle_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Max_Angle_Limit", values, motor_ids)
def write_max_temperature_limit(self, value, motor_idx: int):
self.write("Max_Temperature_Limit", value, motor_idx)
def read_max_temperature_limit(self, motor_idx: int):
return self.read("Max_Temperature_Limit", motor_idx)
def sync_read_max_temperature_limit(self, motor_ids: list[int] | None = None):
return self.sync_read("Max_Temperature_Limit", motor_ids)
def sync_write_max_temperature_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Max_Temperature_Limit", values, motor_ids)
def write_max_voltage_limit(self, value, motor_idx: int):
self.write("Max_Voltage_Limit", value, motor_idx)
def read_max_voltage_limit(self, motor_idx: int):
return self.read("Max_Voltage_Limit", motor_idx)
def sync_read_max_voltage_limit(self, motor_ids: list[int] | None = None):
return self.sync_read("Max_Voltage_Limit", motor_ids)
def sync_write_max_voltage_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Max_Voltage_Limit", values, motor_ids)
def write_min_voltage_limit(self, value, motor_idx: int):
self.write("Min_Voltage_Limit", value, motor_idx)
def read_min_voltage_limit(self, motor_idx: int):
return self.read("Min_Voltage_Limit", motor_idx)
def sync_read_min_voltage_limit(self, motor_ids: list[int] | None = None):
return self.sync_read("Min_Voltage_Limit", motor_ids)
def sync_write_min_voltage_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Min_Voltage_Limit", values, motor_ids)
def write_max_torque_limit(self, value, motor_idx: int):
self.write("Max_Torque_Limit", value, motor_idx)
def read_max_torque_limit(self, motor_idx: int):
return self.read("Max_Torque_Limit", motor_idx)
def sync_read_max_torque_limit(self, motor_ids: list[int] | None = None):
return self.sync_read("Max_Torque_Limit", motor_ids)
def sync_write_max_torque_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Max_Torque_Limit", values, motor_ids)
def write_p_coefficient(self, value, motor_idx: int):
self.write("P_Coefficient", value, motor_idx)
def read_p_coefficient(self, motor_idx: int):
return self.read("P_Coefficient", motor_idx)
def sync_read_p_coefficient(self, motor_ids: list[int] | None = None):
return self.sync_read("P_Coefficient", motor_ids)
def sync_write_p_coefficient(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("P_Coefficient", values, motor_ids)
def write_d_coefficient(self, value, motor_idx: int):
self.write("D_Coefficient", value, motor_idx)
def read_d_coefficient(self, motor_idx: int):
return self.read("D_Coefficient", motor_idx)
def sync_read_d_coefficient(self, motor_ids: list[int] | None = None):
return self.sync_read("D_Coefficient", motor_ids)
def sync_write_d_coefficient(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("D_Coefficient", values, motor_ids)
def write_i_coefficient(self, value, motor_idx: int):
self.write("I_Coefficient", value, motor_idx)
def read_i_coefficient(self, motor_idx: int):
return self.read("I_Coefficient", motor_idx)
def sync_read_i_coefficient(self, motor_ids: list[int] | None = None):
return self.sync_read("I_Coefficient", motor_ids)
def sync_write_i_coefficient(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("I_Coefficient", values, motor_ids)
def write_minimum_startup_force(self, value, motor_idx: int):
self.write("Minimum_Startup_Force", value, motor_idx)
def read_minimum_startup_force(self, motor_idx: int):
return self.read("Minimum_Startup_Force", motor_idx)
def sync_read_minimum_startup_force(self, motor_ids: list[int] | None = None):
return self.sync_read("Minimum_Startup_Force", motor_ids)
def sync_write_minimum_startup_force(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Minimum_Startup_Force", values, motor_ids)
def write_cw_dead_zone(self, value, motor_idx: int):
self.write("CW_Dead_Zone", value, motor_idx)
def read_cw_dead_zone(self, motor_idx: int):
return self.read("CW_Dead_Zone", motor_idx)
def sync_read_cw_dead_zone(self, motor_ids: list[int] | None = None):
return self.sync_read("CW_Dead_Zone", motor_ids)
def sync_write_cw_dead_zone(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("CW_Dead_Zone", values, motor_ids)
def write_ccw_dead_zone(self, value, motor_idx: int):
self.write("CCW_Dead_Zone", value, motor_idx)
def read_ccw_dead_zone(self, motor_idx: int):
return self.read("CCW_Dead_Zone", motor_idx)
def sync_read_ccw_dead_zone(self, motor_ids: list[int] | None = None):
return self.sync_read("CCW_Dead_Zone", motor_ids)
def sync_write_ccw_dead_zone(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("CCW_Dead_Zone", values, motor_ids)
def write_protection_current(self, value, motor_idx: int):
self.write("Protection_Current", value, motor_idx)
def read_protection_current(self, motor_idx: int):
return self.read("Protection_Current", motor_idx)
def sync_read_protection_current(self, motor_ids: list[int] | None = None):
return self.sync_read("Protection_Current", motor_ids)
def sync_write_protection_current(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Protection_Current", values, motor_ids)
def read_angular_resolution(self, motor_idx: int):
return self.read("Angular_Resolution", motor_idx)
def sync_read_angular_resolution(self, motor_ids: list[int] | None = None):
return self.sync_read("Angular_Resolution", motor_ids)
def write_offset(self, value, motor_idx: int):
self.write("Offset", value, motor_idx)
def read_offset(self, motor_idx: int):
return self.read("Offset", motor_idx)
def sync_read_offset(self, motor_ids: list[int] | None = None):
return self.sync_read("Offset", motor_ids)
def sync_write_offset(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Offset", values, motor_ids)
def write_mode(self, value, motor_idx: int):
self.write("Mode", value, motor_idx)
def read_mode(self, motor_idx: int):
return self.read("Mode", motor_idx)
def sync_read_mode(self, motor_ids: list[int] | None = None):
return self.sync_read("Mode", motor_ids)
def sync_write_mode(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Mode", values, motor_ids)
def write_protective_torque(self, value, motor_idx: int):
self.write("Protective_Torque", value, motor_idx)
def read_protective_torque(self, motor_idx: int):
return self.read("Protective_Torque", motor_idx)
def sync_read_protective_torque(self, motor_ids: list[int] | None = None):
return self.sync_read("Protective_Torque", motor_ids)
def sync_write_protective_torque(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Protective_Torque", values, motor_ids)
def read_protection_time(self, motor_idx: int):
return self.read("Protection_Time", motor_idx)
def sync_read_protection_time(self, motor_ids: list[int] | None = None):
return self.sync_read("Protection_Time", motor_ids)
def write_speed_closed_loop_p_proportional_coefficient(self, value, motor_idx: int):
self.write("Speed_closed_loop_P_proportional_coefficient", value, motor_idx)
def read_speed_closed_loop_p_proportional_coefficient(self, motor_idx: int):
return self.read("Speed_closed_loop_P_proportional_coefficient", motor_idx)
def sync_read_speed_closed_loop_p_proportional_coefficient(self, motor_ids: list[int] | None = None):
return self.sync_read("Speed_closed_loop_P_proportional_coefficient", motor_ids)
def sync_write_speed_closed_loop_p_proportional_coefficient(self, values: int | list[int],
motor_ids: list[int] | None = None):
self.sync_write("Speed_closed_loop_P_proportional_coefficient", values, motor_ids)
def write_over_current_protection_time(self, value, motor_idx: int):
self.write("Over_Current_Protection_Time", value, motor_idx)
def read_over_current_protection_time(self, motor_idx: int):
return self.read("Over_Current_Protection_Time", motor_idx)
def sync_read_over_current_protection_time(self, motor_ids: list[int] | None = None):
return self.sync_read("Over_Current_Protection_Time", motor_ids)
def sync_write_over_current_protection_time(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Over_Current_Protection_Time", values, motor_ids)
def write_velocity_closed_loop_i_integral_coefficient(self, value, motor_idx: int):
self.write("Velocity_closed_loop_I_integral_coefficient", value, motor_idx)
def read_velocity_closed_loop_i_integral_coefficient(self, motor_idx: int):
return self.read("Velocity_closed_loop_I_integral_coefficient", motor_idx)
def sync_read_velocity_closed_loop_i_integral_coefficient(self, motor_ids: list[int] | None = None):
return self.sync_read("Velocity_closed_loop_I_integral_coefficient", motor_ids)
def sync_write_velocity_closed_loop_i_integral_coefficient(self, values: int | list[int],
motor_ids: list[int] | None = None):
self.sync_write("Velocity_closed_loop_I_integral_coefficient", values, motor_ids)
def write_torque_enable(self, value, motor_idx: int):
self.write("Torque_Enable", value, motor_idx)
def read_torque_enable(self, motor_idx: int):
return self.read("Torque_Enable", motor_idx)
def sync_read_torque_enable(self, motor_ids: list[int] | None = None):
return self.sync_read("Torque_Enable", motor_ids)
def sync_write_torque_enable(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Torque_Enable", values, motor_ids)
def write_goal_position_u32(self, value, motor_idx: int):
self.write("Goal_Position", value, motor_idx)
def write_goal_position_i32(self, value, motor_idx: int):
self.write("Goal_Position", i32_to_u32(value), motor_idx)
def read_goal_position_u32(self, motor_idx: int):
return self.read("Goal_Position", motor_idx)
def read_goal_position_i32(self, motor_idx: int):
goal_position_u32 = self.read_goal_position_u32(motor_idx)
return u32_to_i32(goal_position_u32)
def sync_read_goal_position_u32(self, motor_ids: list[int] | None = None):
return self.sync_read("Goal_Position", motor_ids)
def sync_read_goal_position_i32(self, motor_ids: list[int] | None = None):
goal_position_u32 = self.sync_read_goal_position_u32(motor_ids)
return u32_to_i32(goal_position_u32)
def sync_write_goal_position_u32(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Goal_Position", values, motor_ids)
def sync_write_goal_position_i32(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Goal_Position", i32_to_u32(values), motor_ids)
def write_goal_time(self, value, motor_idx: int):
self.write("Goal_Time", value, motor_idx)
def read_goal_time(self, motor_idx: int):
return self.read("Goal_Time", motor_idx)
def sync_read_goal_time(self, motor_ids: list[int] | None = None):
return self.sync_read("Goal_Time", motor_ids)
def sync_write_goal_time(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Goal_Time", values, motor_ids)
def write_goal_speed(self, value, motor_idx: int):
self.write("Goal_Speed", value, motor_idx)
def read_goal_speed(self, motor_idx: int):
return self.read("Goal_Speed", motor_idx)
def sync_read_goal_speed(self, motor_ids: list[int] | None = None):
return self.sync_read("Goal_Speed", motor_ids)
def sync_write_goal_speed(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Goal_Speed", values, motor_ids)
def write_lock(self, value, motor_idx: int):
self.write("Lock", value, motor_idx)
def read_lock(self, motor_idx: int):
return self.read("Lock", motor_idx)
def sync_read_lock(self, motor_ids: list[int] | None = None):
return self.sync_read("Lock", motor_ids)
def sync_write_lock(self, values: int | list[int], motor_ids: list[int] | None = None):
self.sync_write("Lock", values, motor_ids)
def read_present_position_u32(self, motor_idx: int):
return self.read("Present_Position", motor_idx)
def read_present_position_i32(self, motor_idx: int):
present_position_u32 = self.read_present_position_u32(motor_idx)
return u32_to_i32(present_position_u32)
def sync_read_present_position_u32(self, motor_ids: list[int] | None = None):
return self.sync_read("Present_Position", motor_ids)
def sync_read_present_position_i32(self, motor_ids: list[int] | None = None):
present_position_u32 = self.sync_read_present_position_u32(motor_ids)
return u32_to_i32(present_position_u32)
def read_present_speed(self, motor_idx: int):
return self.read("Present_Speed", motor_idx)
def sync_read_present_speed(self, motor_ids: list[int] | None = None):
return self.sync_read("Present_Speed", motor_ids)
def read_present_load(self, motor_idx: int):
return self.read("Present_Load", motor_idx)
def sync_read_present_load(self, motor_ids: list[int] | None = None):
return self.sync_read("Present_Load", motor_ids)
def read_present_voltage(self, motor_idx: int):
return self.read("Present_Voltage", motor_idx)
def sync_read_present_voltage(self, motor_ids: list[int] | None = None):
return self.sync_read("Present_Voltage", motor_ids)
def read_present_temperature(self, motor_idx: int):
return self.read("Present_Temperature", motor_idx)
def sync_read_present_temperature(self, motor_ids: list[int] | None = None):
return self.sync_read("Present_Temperature", motor_ids)
def read_moving(self, motor_idx: int):
return self.read("Moving", motor_idx)
def sync_read_moving(self, motor_ids: list[int] | None = None):
return self.sync_read("Moving", motor_ids)
def read_present_current(self, motor_idx: int):
return self.read("Present_Current", motor_idx)
def sync_read_present_current(self, motor_ids: list[int] | None = None):
return self.sync_read("Present_Current", motor_ids)

View File

@@ -1,9 +0,0 @@
from typing import Protocol
class MotorsBus(Protocol):
def motor_names(self): ...
def set_calibration(self): ...
def apply_calibration(self): ...
def revert_calibration(self): ...
def read(self): ...
def write(self): ...

View File

@@ -1,204 +0,0 @@
import copy
from dataclasses import dataclass, field, replace
import numpy as np
import torch
from examples.real_robot_example.gym_real_world.robot import Robot, pos2pwm, pwm2pos  # pos2pwm/pwm2pos are assumed to live in the same module; they are used below
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
from lerobot.common.robot_devices.cameras.utils import Camera
MAX_LEADER_GRIPPER_RAD = 0.7761942786701344
MAX_LEADER_GRIPPER_POS = 2567
MAX_FOLLOWER_GRIPPER_RAD = 1.6827769243105486
MAX_FOLLOWER_GRIPPER_POS = 3100
MIN_LEADER_GRIPPER_RAD = -0.12732040539450828
MIN_LEADER_GRIPPER_POS = 1984
MIN_FOLLOWER_GRIPPER_RAD = 0.6933593161243099
MIN_FOLLOWER_GRIPPER_POS = 2512
GRIPPER_INDEX = -1
def convert_gripper_range_from_leader_to_follower(leader_pos):
follower_goal_pos = copy.copy(leader_pos)
follower_goal_pos[GRIPPER_INDEX] = \
(leader_pos[GRIPPER_INDEX] - MIN_LEADER_GRIPPER_POS) \
/ (MAX_LEADER_GRIPPER_POS - MIN_LEADER_GRIPPER_POS) \
* (MAX_FOLLOWER_GRIPPER_POS - MIN_FOLLOWER_GRIPPER_POS) \
+ MIN_FOLLOWER_GRIPPER_POS
return follower_goal_pos
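# Worked example with the constants above: a leader gripper at its minimum position (1984) maps to the
# follower minimum (2512), the leader maximum (2567) maps to the follower maximum (3100), and values in
# between are interpolated linearly.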
@dataclass
class AlohaRobotConfig:
"""
Example of usage:
```python
AlohaRobotConfig()
```
Example of only using left arm:
```python
AlohaRobotConfig(
activated_leaders=["left"],
activated_followers=["left"],
)
```
"""
# Define all the components of the robot
leader_devices: dict[str, str] = field(
default_factory=lambda: {
"right": {
#"port": "/dev/ttyDXL_master_right",
"port": "/dev/ttyDXL_master_left",
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
},
"left": {
"port": "/dev/ttyDXL_master_left",
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
},
}
)
follower_devices: dict[str, str] = field(
default_factory=lambda: {
"right": {
"port": "/dev/ttyDXL_puppet_right",
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
},
"left": {
"port": "/dev/ttyDXL_puppet_left",
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
},
}
)
camera_devices: dict[str, Camera] = field(
default_factory=lambda: {
# "cam_high": OpenCVCamera(16),
# "cam_low": OpenCVCamera(4),
# "cam_left_wrist": OpenCVCamera(10),
# "cam_right_wrist": OpenCVCamera(22),
}
)
# Makes it easy to pick a subset of all devices
activated_leaders: list[str] | None = field(
default_factory=lambda: ["left", "right"]
)
activated_followers: list[str] | None = field(
default_factory=lambda: ["left", "right"]
)
activated_cameras: list[str] | None = field(
default_factory=lambda: ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
)
class AlohaRobot():
""" Trossen Robotics
Example of usage:
```python
robot = AlohaRobot()
```
"""
def __init__(self, config: AlohaRobotConfig | None = None, **kwargs):
if config is None:
config = AlohaRobotConfig()
# Overwrite config arguments using kwargs
config = replace(config, **kwargs)
self.config = config
self.leaders = {}
self.followers = {}
self.cameras = {}
if config.activated_leaders:
for name in config.activated_leaders:
info = config.leader_devices[name]
self.leaders[name] = Robot(info["port"], servo_ids=info["servos"])
if config.activated_followers:
for name in config.activated_followers:
info = config.follower_devices[name]
self.followers[name] = Robot(info["port"], servo_ids=info["servos"])
if config.activated_cameras:
for name in config.activated_cameras:
self.cameras[name] = config.camera_devices[name]
def init_teleop(self):
for name in self.followers:
self.followers[name]._enable_torque()
for name in self.cameras:
self.cameras[name].connect()
def teleop_step(self, record_data=False) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
# Prepare to assign the positions of the leader to the follower
leader_pos = {}
for name in self.leaders:
leader_pos[name] = self.leaders[name].read_position()
# Update the position of the follower gripper to account for the different minimum and maximum range
# Positions are in the range [0, 4096), which corresponds to 4096 bins over 360 degrees
# for all our Dynamixel servos.
# gripper id=8 has a different range from leader to follower
follower_goal_pos = {}
for name in self.leaders:
follower_goal_pos[name] = convert_gripper_range_from_leader_to_follower(leader_pos[name])
# Send action
for name in self.followers:
self.followers[name].set_goal_pos(follower_goal_pos[name])
# Early exit when recording data is not requested
if not record_data:
return
# Read follower position
follower_pos = {}
for name in self.followers:
follower_pos[name] = self.followers[name].read_position()
# Create state by concatenating follower current position
state = []
for name in ["left", "right"]:
if name in follower_pos:
state.append(follower_pos[name])
state = np.concatenate(state)
state = pwm2pos(state)
# Create action by concatenating follower goal position
action = []
for name in ["left", "right"]:
if name in follower_goal_pos:
action.append(follower_goal_pos[name])
action = np.concatenate(action)
action = pwm2pos(action)
# Capture images from cameras
images = {}
for name in self.cameras:
images[name] = self.cameras[name].read()
# Populate output dictionaries and convert to PyTorch tensors
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = torch.from_numpy(state)
action_dict["action"] = torch.from_numpy(action)
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
return obs_dict, action_dict
def send_action(self, action):
from_idx = 0
to_idx = 0
follower_goal_pos = {}
for name in ["left", "right"]:
if name in self.followers:
to_idx += len(self.config.follower_devices[name]["servos"])
follower_goal_pos[name] = pos2pwm(action[from_idx:to_idx].numpy())
from_idx = to_idx
for name in self.followers:
self.followers[name].set_goal_pos(follower_goal_pos[name])

View File

@@ -1,46 +0,0 @@
def make_robot(name):
if name == "koch":
from lerobot.common.robot_devices.robots.koch import KochRobot
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
robot = KochRobot(
leader_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl330-m077"),
"shoulder_lift": (2, "xl330-m077"),
"elbow_flex": (3, "xl330-m077"),
"wrist_flex": (4, "xl330-m077"),
"wrist_roll": (5, "xl330-m077"),
"gripper": (6, "xl330-m077"),
},
),
},
follower_arms={
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
},
),
},
cameras={
"main": OpenCVCamera(1, fps=30, width=640, height=480),
}
)
else:
raise ValueError(f"Robot '{name}' not found.")
return robot
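# A minimal usage sketch of this factory (the ports and camera index above are examples to adjust):
# robot = make_robot("koch")
# robot.init_teleop()
# robot.teleop_step(record_data=False)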

View File

@@ -1,365 +0,0 @@
import copy
from dataclasses import dataclass, field, replace
from pathlib import Path
import pickle
import numpy as np
import torch
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode, motor_position_to_angle
from lerobot.common.robot_devices.motors.utils import MotorsBus
########################################################################
# Calibration logic
########################################################################
# TARGET_HORIZONTAL_POSITION = motor_position_to_angle(np.array([0, -1024, 1024, 0, -1024, 0]))
# TARGET_90_DEGREE_POSITION = motor_position_to_angle(np.array([1024, 0, 0, 1024, 0, -1024]))
# GRIPPER_OPEN = motor_position_to_angle(np.array([-400]))
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
GRIPPER_OPEN = np.array([-400])
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
for i in range(len(values)):
if values[i] is not None:
values[i] += homing_offset[i]
return values
def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
for i in range(len(values)):
if values[i] is not None and drive_mode[i]:
values[i] = -values[i]
return values
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
values = apply_drive_mode(values, drive_mode)
values = apply_homing_offset(values, homing_offset)
return values
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
"""
Transform working position into real position for the robot.
"""
values = apply_homing_offset(values, np.array([
-offset if offset is not None else None for offset in homing_offset
]))
values = apply_drive_mode(values, drive_mode)
return values
def revert_appropriate_positions(positions: np.array, drive_mode: list[bool]) -> np.array:
for i, revert in enumerate(drive_mode):
if not revert and positions[i] is not None:
positions[i] = -positions[i]
return positions
def compute_corrections(positions: np.array, drive_mode: list[bool], target_position: np.array) -> np.array:
correction = revert_appropriate_positions(positions, drive_mode)
for i in range(len(positions)):
if correction[i] is not None:
if drive_mode[i]:
correction[i] -= target_position[i]
else:
correction[i] += target_position[i]
return correction
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
return np.array(
[round(positions[i] / 1024) * 1024 if positions[i] is not None else None for i in range(len(positions))])
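# Example: positions are snapped to the nearest multiple of 1024 (one quarter turn, since 4096 steps = 360
# degrees), e.g. 1000 -> 1024 and 2100 -> 2048; None entries are passed through unchanged.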
def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array) -> np.array:
# Get the present positions of the servos
present_positions = apply_calibration(
arm.read("Present_Position"),
np.array([0, 0, 0, 0, 0, 0]),
drive_mode)
nearest_positions = compute_nearest_rounded_positions(present_positions)
correction = compute_corrections(nearest_positions, drive_mode, target_position)
return correction
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
# Get current positions
present_positions = apply_calibration(
arm.read("Present_Position"),
offset,
np.array([False, False, False, False, False, False]))
nearest_positions = compute_nearest_rounded_positions(present_positions)
# construct 'drive_mode' list comparing nearest_positions and TARGET_90_DEGREE_POSITION
drive_mode = []
for i in range(len(nearest_positions)):
drive_mode.append(nearest_positions[i] != TARGET_90_DEGREE_POSITION[i])
return drive_mode
def reset_arm(arm: MotorsBus):
# To be configured, all servos must be in "torque disable" mode
arm.write("Torque_Enable", TorqueMode.DISABLED.value)
# Use 'extended position mode' for all motors except the gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095). Also, a small assembly mistake could leave a
# servo at position 0 or 4095 at a crucial point. See
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
# TODO(rcadene): why?
# Use 'position control current based' for gripper
arm.write("Operating_Mode", OperatingMode.CURRENT_CONTROLLED_POSITION.value, "gripper")
# Make sure the native calibration (homing offset and drive mode) is disabled, since we use our own
# calibration layer to be more generic.
arm.write("Homing_Offset", 0)
arm.write("Drive_Mode", DriveMode.NON_INVERTED.value)
def run_arm_calibration(arm: MotorsBus, name: str):
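# Guides the user through two reference poses and derives, for each motor, a homing offset (int) and a
# drive mode flag (bool), returned in the order of `arm.motor_names`.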
reset_arm(arm)
# TODO(rcadene): document what position 1 means
print(f"Please move the '{name}' arm to the horizontal position (gripper fully closed)")
input("Press Enter to continue...")
horizontal_homing_offset = compute_homing_offset(arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION)
# TODO(rcadene): document what position 2 means
print(f"Please move the '{name}' arm to the 90 degree position (gripper fully open)")
input("Press Enter to continue...")
drive_mode = compute_drive_mode(arm, horizontal_homing_offset)
homing_offset = compute_homing_offset(arm, drive_mode, TARGET_90_DEGREE_POSITION)
# Invert offset for all drive_mode servos
for i in range(len(drive_mode)):
if drive_mode[i]:
homing_offset[i] = -homing_offset[i]
print("Calibration is done!")
print("=====================================")
print(" HOMING_OFFSET: ", " ".join([str(i) for i in homing_offset]))
print(" DRIVE_MODE: ", " ".join([str(i) for i in drive_mode]))
print("=====================================")
return homing_offset, drive_mode
########################################################################
# Alexander Koch robot arm
########################################################################
@dataclass
class KochRobotConfig:
"""
Example of usage:
```python
KochRobotConfig()
```
"""
# Define all the components of the robot
leader_arms: dict[str, MotorsBus] = field(
default_factory=lambda: {
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl330-m077"),
"shoulder_lift": (2, "xl330-m077"),
"elbow_flex": (3, "xl330-m077"),
"wrist_flex": (4, "xl330-m077"),
"wrist_roll": (5, "xl330-m077"),
"gripper": (6, "xl330-m077"),
},
),
}
)
follower_arms: dict[str, MotorsBus] = field(
default_factory=lambda: {
"main": DynamixelMotorsBus(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": (1, "xl430-w250"),
"shoulder_lift": (2, "xl430-w250"),
"elbow_flex": (3, "xl330-m288"),
"wrist_flex": (4, "xl330-m288"),
"wrist_roll": (5, "xl330-m288"),
"gripper": (6, "xl330-m288"),
},
),
}
)
cameras: dict[str, Camera] = field(
default_factory=lambda: {}
)
class KochRobot():
""" Tau Robotics: https://tau-robotics.com
Example of usage:
```python
robot = KochRobot()
```
"""
def __init__(self, config: KochRobotConfig | None = None, calibration_path: Path = ".cache/calibration/koch.pkl", **kwargs):
if config is None:
config = KochRobotConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.calibration_path = Path(calibration_path)
self.leader_arms = self.config.leader_arms
self.follower_arms = self.config.follower_arms
self.cameras = self.config.cameras
def init_teleop(self):
if self.calibration_path.exists():
# Reset all arms before setting calibration
for name in self.follower_arms:
reset_arm(self.follower_arms[name])
for name in self.leader_arms:
reset_arm(self.leader_arms[name])
with open(self.calibration_path, 'rb') as f:
calibration = pickle.load(f)
else:
calibration = self.run_calibration()
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.calibration_path, 'wb') as f:
pickle.dump(calibration, f)
for name in self.follower_arms:
self.follower_arms[name].set_calibration(calibration[f"follower_{name}"])
self.follower_arms[name].write("Torque_Enable", 1)
for name in self.leader_arms:
self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])
# TODO(rcadene): add comments
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN, "gripper")
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
for name in self.cameras:
self.cameras[name].connect()
def run_calibration(self):
calibration = {}
for name in self.follower_arms:
homing_offset, drive_mode = run_arm_calibration(self.follower_arms[name], f"{name} follower")
calibration[f"follower_{name}"] = {}
for idx, motor_name in enumerate(self.follower_arms[name].motor_names):
calibration[f"follower_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
for name in self.leader_arms:
homing_offset, drive_mode = run_arm_calibration(self.leader_arms[name], f"{name} leader")
calibration[f"leader_{name}"] = {}
for idx, motor_name in enumerate(self.leader_arms[name].motor_names):
calibration[f"leader_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
return calibration
def teleop_step(self, record_data=False) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
# Prepare to assign the positions of the leader to the follower
leader_pos = {}
for name in self.leader_arms:
leader_pos[name] = self.leader_arms[name].read("Present_Position")
follower_goal_pos = {}
for name in self.leader_arms:
follower_goal_pos[name] = leader_pos[name]
# Send action
for name in self.follower_arms:
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
# Early exit when recording data is not requested
if not record_data:
return
# Read follower position
follower_pos = {}
for name in self.follower_arms:
follower_pos[name] = self.follower_arms[name].read("Present_Position")
# Create state by concatenating follower current position
state = []
for name in self.follower_arms:
if name in follower_pos:
state.append(follower_pos[name])
state = np.concatenate(state)
# Create action by concatenating follower goal position
action = []
for name in self.follower_arms:
if name in follower_goal_pos:
action.append(follower_goal_pos[name])
action = np.concatenate(action)
# Capture images from cameras
images = {}
for name in self.cameras:
images[name] = self.cameras[name].read()
# Populate output dictionaries and convert to PyTorch tensors
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = torch.from_numpy(state)
action_dict["action"] = torch.from_numpy(action)
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
return obs_dict, action_dict
def capture_observation(self):
# Read follower position
follower_pos = {}
for name in self.follower_arms:
follower_pos[name] = self.follower_arms[name].read("Present_Position")
# Create state by concatenating follower current position
state = []
for name in self.follower_arms:
if name in follower_pos:
state.append(follower_pos[name])
state = np.concatenate(state)
# Capture images from cameras
images = {}
for name in self.cameras:
images[name] = self.cameras[name].read()
# Populate output dictionaries and convert to PyTorch tensors
obs_dict = {}
obs_dict["observation.state"] = torch.from_numpy(state)
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
return obs_dict
def send_action(self, action):
from_idx = 0
to_idx = 0
follower_goal_pos = {}
for name in self.follower_arms:
if name in self.follower_arms:
to_idx += len(self.follower_arms[name].motor_names)
follower_goal_pos[name] = action[from_idx:to_idx].numpy()
from_idx = to_idx
for name in self.follower_arms:
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))

View File

@@ -1,8 +0,0 @@
from typing import Protocol
class Robot(Protocol):
def init_teleop(self): ...
def run_calibration(self): ...
def teleop_step(self, record_data=False): ...
def capture_observation(self): ...
def send_action(self, action): ...

View File

@@ -25,42 +25,3 @@ def write_video(video_path, stacked_frames, fps):
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
)
imageio.mimsave(video_path, stacked_frames, fps=fps)
import serial
import os
import time
def reset_usb_port(port):
try:
# Close the serial port if it's open
ser = serial.Serial(port)
ser.close()
except serial.serialutil.SerialException as e:
print(f"Exception while closing the port: {e}")
# Find the USB device path
usb_device_path = None
for root, dirs, files in os.walk('/sys/bus/usb/drivers/usb'):
for dir_name in dirs:
if port in dir_name:
usb_device_path = os.path.join(root, dir_name)
break
if usb_device_path:
# Unbind and rebind the USB device
try:
unbind_path = os.path.join(usb_device_path, 'unbind')
bind_path = os.path.join(usb_device_path, 'bind')
usb_id = os.path.basename(usb_device_path)
with open(unbind_path, 'w') as f:
f.write(usb_id)
time.sleep(1) # Wait for a second
with open(bind_path, 'w') as f:
f.write(usb_id)
print(f"USB port {port} has been reset.")
except Exception as e:
print(f"Exception during USB reset: {e}")
else:
print(f"Could not find USB device path for port: {port}")

View File

@@ -19,7 +19,7 @@ import random
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Generator
from typing import Generator
import hydra
import numpy as np
@@ -48,38 +48,12 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
return device
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_global_random_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_global_random_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_global_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed_all(seed)
@contextmanager
@@ -95,10 +69,16 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state_dict = get_global_random_state()
random_state = random.getstate()
np_random_state = np.random.get_state()
torch_random_state = torch.random.get_rng_state()
torch_cuda_random_state = torch.cuda.random.get_rng_state()
set_global_seed(seed)
yield None
set_global_random_state(random_state_dict)
random.setstate(random_state)
np.random.set_state(np_random_state)
torch.random.set_rng_state(torch_random_state)
torch.cuda.random.set_rng_state(torch_cuda_random_state)
def init_logging():
@@ -120,13 +100,13 @@ def init_logging():
logging.getLogger().addHandler(console_handler)
def format_big_number(num, precision=0):
def format_big_number(num):
suffixes = ["", "K", "M", "B", "T", "Q"]
divisor = 1000.0
for suffix in suffixes:
if abs(num) < divisor:
return f"{num:.{precision}f}{suffix}"
return f"{num:.0f}{suffix}"
num /= divisor
return num
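# e.g. format_big_number(1234567) -> "1M", format_big_number(950) -> "950"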

View File

@@ -5,33 +5,18 @@ defaults:
hydra:
run:
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
job:
name: default
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
# `hydra.run.dir` is the directory of an existing run with at least one checkpoint in it.
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: false
device: cuda # cpu
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
dataset_repo_id: lerobot/pusht
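# For example (hypothetical combination), a list is also accepted:
# dataset_repo_id: [lerobot/pusht, lerobot/aloha_sim_insertion_human]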
training:
offline_steps: ???
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
online_steps: ???
online_steps_between_rollouts: ???
online_sampling_ratio: 0.5
@@ -40,9 +25,7 @@ training:
eval_freq: ???
save_freq: ???
log_freq: 250
save_checkpoint: true
num_workers: 4
batch_size: ???
save_model: true
eval:
n_episodes: 1
@@ -50,12 +33,10 @@ eval:
batch_size: 1
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: false
# Specify the number of episodes to render during evaluation.
max_episodes_rendered: 10
wandb:
enable: false
# Set to true to disable saving an artifact despite save_checkpoint == True
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot
notes: ""

View File

@@ -5,11 +5,10 @@ fps: 50
env:
name: aloha
task: AlohaInsertion-v0
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
episode_length: 400
fps: ${fps}
state_dim: 14
action_dim: 14
fps: ${fps}
episode_length: 400
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

lerobot/configs/env/dora.yaml (new file)
View File

@@ -0,0 +1,14 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraAloha-v0
# from_pixels: True
# pixels_only: False
# image_size: [3, 480, 640]
episode_length: 400
# fps: ${fps}
# state_dim: 14
# action_dim: 14

View File

@@ -1,14 +0,0 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraAloha-v0
state_dim: 14
action_dim: 14
fps: ${fps}
episode_length: 400
real_world: true
gym:
fps: ${fps}

View File

@@ -5,14 +5,10 @@ fps: 10
env:
name: pusht
task: PushT-v0
from_pixels: True
pixels_only: False
image_size: 96
episode_length: 300
fps: ${fps}
state_dim: 2
action_dim: 2
fps: ${fps}
episode_length: 300
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array
visualization_width: 384
visualization_height: 384

View File

@@ -5,14 +5,10 @@ fps: 15
env:
name: xarm
task: XarmLift-v0
from_pixels: True
pixels_only: False
image_size: 84
episode_length: 25
fps: ${fps}
state_dim: 4
action_dim: 4
fps: ${fps}
episode_length: 25
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array
visualization_width: 384
visualization_height: 384

View File

@@ -15,7 +15,7 @@ training:
eval_freq: 10000
save_freq: 100000
log_freq: 250
save_checkpoint: true
save_model: true
batch_size: 8
lr: 1e-5

View File

@@ -1,115 +0,0 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it uses 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This parameter
# controls how often checkpoints are evaluated during training; setting it to -1 deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=dora_aloha_real
# ```
seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup
override_dataset_stats:
observation.images.cam_right_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_left_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 10000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_right_wrist: [3, 480, 640]
observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_right_wrist: mean_std
observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -1,19 +1,7 @@
# @package _global_
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras).
# Compared to `act_real.yaml`, it is camera-only and does not use the state (the vector of robot joint positions) as input.
# We validated experimentally that not using the state achieves a better success rate. Our hypothesis is that `act_real.yaml`
# might overfit to the state, because the moving images are harder to learn from.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real_no_state \
# env=dora_aloha_real
# ```
seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup
dataset_repo_id: cadene/aloha_v2_static_dora_test
override_dataset_stats:
observation.images.cam_right_wrist:
@@ -36,10 +24,10 @@ override_dataset_stats:
training:
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 10000
eval_freq: 99999999999999
save_freq: 1000
log_freq: 100
save_checkpoint: true
save_model: true
batch_size: 8
lr: 1e-5
@@ -70,6 +58,7 @@ policy:
observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
@@ -79,6 +68,7 @@ policy:
observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std

View File

@@ -27,7 +27,7 @@ training:
eval_freq: 5000
save_freq: 5000
log_freq: 250
save_checkpoint: true
save_model: true
batch_size: 64
grad_clip_norm: 10
@@ -44,10 +44,6 @@ training:
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
eval:
n_episodes: 50
batch_size: 50

View File

@@ -5,8 +5,7 @@ dataset_repo_id: lerobot/xarm_lift_medium
training:
offline_steps: 25000
# TODO(alexander-soare): uncomment when online training gets reinstated
online_steps: 0 # 25000 not implemented yet
online_steps: 25000
eval_freq: 5000
online_steps_between_rollouts: 1
online_sampling_ratio: 0.5

View File

@@ -1,400 +0,0 @@
"""
Example of usage:
- Unlimited teleoperation at the highest frequency (~200 Hz expected); exit with CTRL+C:
```bash
python lerobot/scripts/control_robot.py teleoperate
```
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--fps 30
```
- Record one episode in order to test replay:
```bash
python lerobot/scripts/control_robot.py record_dataset \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--num-episodes 1 \
--run-compute-stats 0
```
- Visualize dataset:
```bash
python lerobot/scripts/visualize_dataset.py \
--root tmp/data \
--repo-id $USER/koch_test \
--episode-index 0
```
- Replay this test episode:
```bash
python lerobot/scripts/control_robot.py replay_episode \
--fps 30 \
--root tmp/data \
--repo-id $USER/koch_test \
--episode 0
```
- Record a full dataset in order to train a policy:
```bash
python lerobot/scripts/control_robot.py record_dataset \
--fps 30 \
--root data \
--repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \
--run-compute-stats 1
```
- Train on this dataset (TODO(rcadene)):
```bash
python lerobot/scripts/train.py
```
- Run the pretrained policy on the robot:
```bash
python lerobot/scripts/control_robot.py run_policy \
-p TODO(rcadene)
```
"""
import argparse
from contextlib import nullcontext
import os
from pathlib import Path
import shutil
import time
from PIL import Image
from omegaconf import DictConfig
import torch
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, load_hf_dataset
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
from lerobot.scripts.push_dataset_to_hub import save_meta_data
from lerobot.scripts.robot_controls.record_dataset import record_dataset
import concurrent.futures
########################################################################################
# Utilities
########################################################################################
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
img = Image.fromarray(img_tensor.numpy())
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def busy_wait(seconds):
# Significantly more accurate than `time.sleep`, and mandatory for our use case,
# but it consumes CPU cycles.
# TODO(rcadene): find an alternative
end_time = time.perf_counter() + seconds
while time.perf_counter() < end_time:
pass
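# Example usage (sketch, with `fps` standing for the target loop frequency):
#   start = time.perf_counter()
#   ... do one iteration of work ...
#   busy_wait(1 / fps - (time.perf_counter() - start))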
def none_or_int(value):
if value == 'None':
return None
return int(value)
########################################################################################
# Control modes
########################################################################################
def teleoperate(robot: Robot, fps: int | None = None):
robot.init_teleop()
while True:
now = time.perf_counter()
robot.teleop_step()
if fps is not None:
dt_s = (time.perf_counter() - now)
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
def record_dataset(
    robot: Robot,
    fps: int | None = None,
    root="data",
    repo_id="lerobot/debug",
    warmup_time_s=2,
    episode_time_s=10,
    num_episodes=50,
    video=True,
    run_compute_stats=True,
):
if not video:
raise NotImplementedError()
robot.init_teleop()
local_dir = Path(root) / repo_id
if local_dir.exists():
shutil.rmtree(local_dir)
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
start_time = time.perf_counter()
is_warmup_print = False
is_record_print = False
ep_dicts = []
# Save images using threads to reach high fps (30 and above).
# Using `with` ensures the program exits smoothly if an exception is raised.
with concurrent.futures.ThreadPoolExecutor() as executor:
for episode_index in range(num_episodes):
ep_dict = {}
frame_index = 0
while True:
if not is_warmup_print:
print("Warming up by skipping frames")
os.system('say "Warmup"')
is_warmup_print = True
now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True)
timestamp = time.perf_counter() - start_time
if timestamp < warmup_time_s:
dt_s = (time.perf_counter() - now)
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
continue
if not is_record_print:
print("Recording")
os.system(f'say "Recording episode {episode_index}"')
is_record_print = True
image_keys = [key for key in observation if "image" in key]
not_image_keys = [key for key in observation if "image" not in key]
for key in image_keys:
executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir)
for key in not_image_keys:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(observation[key])
for key in action:
if key not in ep_dict:
ep_dict[key] = []
ep_dict[key].append(action[key])
frame_index += 1
dt_s = (time.perf_counter() - now)
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
if timestamp > episode_time_s - warmup_time_s:
break
print("Encoding to `LeRobotDataset` format")
os.system('say "Encoding"')
num_frames = frame_index
for key in image_keys:
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
# clean temporary images directory
# shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[key] = []
for i in range(num_frames):
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
for key in not_image_keys:
ep_dict[key] = torch.stack(ep_dict[key])
for key in action:
ep_dict[key] = torch.stack(ep_dict[key])
ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
ep_dict["next.done"] = done
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,
}
meta_data_dir = local_dir / "meta_data"
for key in image_keys:
time.sleep(10)
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
# shutil.rmtree(tmp_imgs_dir)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
if run_compute_stats:
stats = compute_stats(lerobot_dataset)
else:
stats = {}
hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
save_meta_data(info, stats, episode_data_index, meta_data_dir)
# TODO(rcadene): push to hub
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action")
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
robot.init_teleop()
print("Replaying episode")
os.system('say "Replaying episode"')
for idx in range(from_idx, to_idx):
now = time.perf_counter()
action = items[idx]["action"]
robot.send_action(action)
dt_s = (time.perf_counter() - now)
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
policy.eval()
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
fps = hydra_cfg.env.fps
while True:
now = time.perf_counter()
observation = robot.capture_observation()
with torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
action = policy.select_action(observation)
robot.send_action(action)
dt_s = (time.perf_counter() - now)
busy_wait(1 / fps - dt_s)
dt_s = (time.perf_counter() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True)
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument("--robot", type=str, default="koch", help="Name of the robot provided to the `make_robot(name)` factory function.")
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
parser_record.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_record.add_argument('--root', type=Path, default="data", help='')
parser_record.add_argument('--repo-id', type=str, default="lerobot/test", help='')
parser_record.add_argument('--warmup-time-s', type=int, default=2, help='')
parser_record.add_argument('--episode-time-s', type=int, default=10, help='')
parser_record.add_argument('--num-episodes', type=int, default=50, help='')
parser_record.add_argument('--run-compute-stats', type=int, default=1, help='')
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
parser_replay.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
parser_replay.add_argument('--root', type=Path, default="data", help='')
parser_replay.add_argument('--repo-id', type=str, default="lerobot/test", help='')
parser_replay.add_argument('--episode', type=int, default=0, help='')
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument('-p', '--pretrained-policy-name-or-path', type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
)
)
parser_policy.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
control_mode = args.mode
robot_name = args.robot
kwargs = vars(args)
del kwargs["mode"]
del kwargs["robot"]
robot = make_robot(robot_name)
if control_mode == "teleoperate":
teleoperate(robot, **kwargs)
elif control_mode == "record_dataset":
record_dataset(robot, **kwargs)
elif control_mode == "replay_episode":
replay_episode(robot, **kwargs)
elif control_mode == "run_policy":
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", args.overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
run_policy(robot, policy, hydra_cfg)

View File

@@ -28,7 +28,7 @@ OR, you want to evaluate a model checkpoint from the LeRobot training script for
```
python lerobot/scripts/eval.py \
-p outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
-p outputs/train/diffusion_pusht/checkpoints/005000 \
eval.n_episodes=10
```
@@ -44,10 +44,8 @@ https://huggingface.co/lerobot/diffusion_pusht/tree/main.
import argparse
import json
import logging
import os
import threading
import time
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
@@ -165,10 +163,7 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
successes = [
i["is_success"] if (i is not None and "is_success" in i) else False
for i in info["final_info"]
]
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
else:
successes = [False] * env.num_envs
@@ -520,13 +515,12 @@ def eval(
out_dir = (
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
)
os.makedirs(out_dir, exist_ok=True)
if out_dir is None:
raise NotImplementedError()
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
get_safe_torch_device(hydra_cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -545,17 +539,16 @@ def eval(
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
info = eval_policy(
env,
policy,
hydra_cfg.eval.n_episodes,
max_episodes_rendered=hydra_cfg.eval.max_episodes_rendered,
video_dir=Path(out_dir) / "eval",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
)
info = eval_policy(
env,
policy,
hydra_cfg.eval.n_episodes,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
)
print(info["aggregated"])
# Save info
@@ -567,31 +560,6 @@ def eval(
logging.info("End of eval")
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(
snapshot_download(pretrained_policy_name_or_path, revision=revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
if __name__ == "__main__":
init_logging()
@@ -626,6 +594,26 @@ if __name__ == "__main__":
if args.pretrained_policy_name_or_path is None:
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
else:
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path, revision=args.revision)
try:
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)

View File

@@ -71,9 +71,9 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict
@@ -144,7 +144,8 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
def push_dataset_to_hub(
data_dir: Path,
input_data_dir: Path,
output_data_dir: Path,
dataset_id: str,
raw_format: str | None,
community_id: str,
@@ -161,34 +162,33 @@ def push_dataset_to_hub(
):
repo_id = f"{community_id}/{dataset_id}"
raw_dir = data_dir / f"{dataset_id}_raw"
out_dir = data_dir / repo_id
meta_data_dir = out_dir / "meta_data"
videos_dir = out_dir / "videos"
meta_data_dir = output_data_dir / "meta_data"
videos_dir = output_data_dir / "videos"
tests_out_dir = tests_data_dir / repo_id
tests_meta_data_dir = tests_out_dir / "meta_data"
tests_videos_dir = tests_out_dir / "videos"
if out_dir.exists():
shutil.rmtree(out_dir)
if output_data_dir.exists():
shutil.rmtree(output_data_dir)
if tests_out_dir.exists() and save_tests_to_disk:
shutil.rmtree(tests_out_dir)
if not raw_dir.exists():
download_raw(raw_dir, dataset_id)
if not input_data_dir.exists():
download_raw(input_data_dir, dataset_id)
if raw_format is None:
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
raise NotImplementedError()
# raw_format = auto_find_raw_format(raw_dir)
# raw_format = auto_find_raw_format(input_data_dir)
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
# convert dataset from original raw format to LeRobot format
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
input_data_dir, output_data_dir, fps, video, debug
)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
@@ -202,7 +202,7 @@ def push_dataset_to_hub(
if save_to_disk:
hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
hf_dataset.save_to_disk(str(out_dir / "train"))
hf_dataset.save_to_disk(str(output_data_dir / "train"))
if not dry_run or save_to_disk:
# mandatory for upload
@@ -236,19 +236,25 @@ def push_dataset_to_hub(
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
if not save_to_disk and out_dir.exists():
if not save_to_disk and output_data_dir.exists():
# remove possible temporary files remaining in the output directory
shutil.rmtree(out_dir)
shutil.rmtree(output_data_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-dir",
"--input-data-dir",
type=Path,
required=True,
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw`).",
)
parser.add_argument(
"--output-data-dir",
type=Path,
required=True,
help="Root directory containing output dataset (e.g. `data/lerobot/aloha_mobile_chair` or `data/lerobot/pusht`).",
)
parser.add_argument(
"--dataset-id",

View File

@@ -1,211 +0,0 @@
"""
LCR Auto Configure: This program is used to automatically configure the Low Cost Robot (LCR) for the user.
The program will:
1. Disable all torque motors of provided LCR.
2. Ask the user to move the LCR to the position 1 (see CONFIGURING.md for more details).
3. Record the position of the LCR.
4. Ask the user to move the LCR to the position 2 (see CONFIGURING.md for more details).
5. Record the position of the LCR.
6. Ask the user to move back the LCR to the position 1.
7. Record the position of the LCR.
8. Calculate the offset of the LCR and save it to the configuration file.
It will also enable all appropriate operating modes for the LCR depending on whether the LCR is a puppet or a master.
"""
import argparse
import time
import numpy as np
from lerobot.common.robot_devices.motors.dynamixel import DynamixelBus, OperatingMode, DriveMode, TorqueMode
def pause():
"""
Pause the program until the user presses the enter key.
"""
input("Press Enter to continue...")
def prepare_configuration(arm: DynamixelBus):
"""
Prepare the configuration for the LCR.
:param arm: DynamixelBus
"""
# To be configured, all servos must be in "torque disable" mode
arm.sync_write_torque_enable(TorqueMode.DISABLED.value)
# We need to work with 'extended position mode' (4) for all servos, because in joint mode (1) the servos can't
# rotate more than 360 degrees (from 0 to 4095). Some mistakes can happen while assembling the arm, and you
# could end up with a servo at position 0 or 4095 at a crucial point. See
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
arm.sync_write_operating_mode(OperatingMode.EXTENDED_POSITION.value, [1, 2, 3, 4, 5])
# Gripper is always 'position control current based' (5)
arm.write_operating_mode(OperatingMode.CURRENT_CONTROLLED_POSITION.value, 6)
# We need to reset the homing offset for all servos
arm.sync_write_homing_offset(0)
# We need to work with 'normal drive mode' (0) for all servos
arm.sync_write_drive_mode(DriveMode.NON_INVERTED.value)
def invert_appropriate_positions(positions: np.array, inverted: list[bool]) -> np.array:
"""
Invert the appropriate positions.
:param positions: numpy array of positions
:param inverted: list of booleans to determine if the position should be inverted
:return: numpy array of inverted positions
"""
for i, invert in enumerate(inverted):
if not invert and positions[i] is not None:
positions[i] = -positions[i]
return positions
def calculate_corrections(positions: np.array, inverted: list[bool]) -> np.array:
"""
Calculate the corrections for the positions.
:param positions: numpy array of positions
:param inverted: list of booleans to determine if the position should be inverted
:return: numpy array of corrections
"""
wanted = wanted_position_1()
correction = invert_appropriate_positions(positions, inverted)
for i in range(len(positions)):
if correction[i] is not None:
if inverted[i]:
correction[i] -= wanted[i]
else:
correction[i] += wanted[i]
return correction
def calculate_nearest_rounded_positions(positions: np.array) -> np.array:
"""
Calculate the nearest rounded positions.
:param positions: numpy array of positions
:return: numpy array of nearest rounded positions
"""
return np.array(
[round(positions[i] / 1024) * 1024 if positions[i] is not None else None for i in range(len(positions))])
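# Example (sketch): positions are snapped to the nearest multiple of 1024 (i.e. 90 degree steps on a
# 4096-tick turn), e.g. 1100 -> 1024 and 1600 -> 2048, while None entries stay None.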
def configure_homing(arm: DynamixelBus, inverted: list[bool]) -> np.array:
"""
Configure the homing for the LCR.
:param arm: DynamixelBus
:param inverted: list of booleans to determine if the position should be inverted
"""
# Reset homing offset for the servos
arm.sync_write_homing_offset(0)
# Get the present positions of the servos
present_positions = arm.sync_read_present_position_i32()
nearest_positions = calculate_nearest_rounded_positions(present_positions)
correction = calculate_corrections(nearest_positions, inverted)
# Write the homing offset for the servos
arm.sync_write_homing_offset(correction)
def configure_drive_mode(arm: DynamixelBus):
"""
Configure the drive mode for the LCR.
:param arm: DynamixelBus
:return: list of booleans indicating which joints had their drive mode inverted
"""
# Get current positions
present_positions = arm.sync_read_present_position_i32()
nearest_positions = calculate_nearest_rounded_positions(present_positions)
# construct 'inverted' list comparing nearest_positions and wanted_position_2
inverted = []
for i in range(len(nearest_positions)):
inverted.append(nearest_positions[i] != wanted_position_2()[i])
# Write the drive mode for the servos
arm.sync_write_drive_mode(
[DriveMode.INVERTED.value if i else DriveMode.NON_INVERTED.value for i in inverted])
return inverted
def wanted_position_1() -> np.array:
"""
The present position wanted in position 1 for the arm
"""
return np.array([0, -1024, 1024, 0, 0, 0])
def wanted_position_2() -> np.array:
"""
The present position wanted in position 2 for the arm
"""
return np.array([1024, 0, 0, 1024, 1024, -1024])
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="LCR Auto Configure: This program is used to automatically configure the Low Cost Robot (LCR) for "
"the user.")
parser.add_argument("--port", type=str, required=True, help="The port of the LCR.")
args = parser.parse_args()
arm = DynamixelBus(
args.port, {
1: "x_series",
2: "x_series",
3: "x_series",
4: "x_series",
5: "x_series",
6: "x_series",
}
)
prepare_configuration(arm)
# Ask the user to move the LCR to the position 1
print("Please move the LCR to the position 1")
pause()
configure_homing(arm, [False, False, False, False, False, False])
# Ask the user to move the LCR to the position 2
print("Please move the LCR to the position 2")
pause()
inverted = configure_drive_mode(arm)
# Ask the user to move back the LCR to the position 1
print("Please move back the LCR to the position 1")
pause()
configure_homing(arm, inverted)
print("Configuration done!")
print("Make sure everything is working properly:")
while True:
positions = arm.sync_read_present_position_i32()
print(positions)
time.sleep(1)

View File

@@ -1,20 +0,0 @@
import time
from lerobot.common.robot_devices.robots.aloha import AlohaRobot
import torch
def record_dataset():
robot = AlohaRobot(use_cameras=True)
robot.init_teleop()
while True:
now = time.time()
observation, action = robot.teleop_step(record_data=True)
dt_s = (time.time() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
if __name__ == "__main__":
record_dataset()

View File

@@ -1,19 +0,0 @@
import time
from lerobot.common.robot_devices.robots.aloha import AlohaRobot
import torch
def record_dataset():
robot = AlohaRobot(use_cameras=True)
robot.init_teleop()
while True:
now = time.time()
observation, action = robot.teleop_step(record_data=True)
dt_s = (time.time() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
if __name__ == "__main__":
record_dataset()

View File

@@ -1,47 +0,0 @@
import time
from lerobot.common.robot_devices.robots.aloha import AlohaRobot
import torch
def teleoperate():
robot = AlohaRobot(use_cameras=False)
robot.init_teleop()
while True:
now = time.time()
robot.teleop_step(record_data=False)
dt_s = (time.time() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
def record_teleop_data():
robot = AlohaRobot(use_cameras=True)
robot.init_teleop()
while True:
now = time.time()
observation, action = robot.teleop_step(record_data=True)
dt_s = (time.time() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
def evaluate_policy(policy):
robot = AlohaRobot(use_cameras=True)
observation = robot.init_evaluate()
while True:
now = time.time()
with torch.inference_mode():
action = policy.select_action(observation)
observation, action = robot.step(action, record_data=False)
dt_s = (time.time() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")

View File

@@ -1,20 +0,0 @@
import time
from lerobot.common.robot_devices.robots.aloha import AlohaRobot
def teleoperate():
robot = AlohaRobot(use_cameras=False)
robot.init_teleop()
while True:
now = time.time()
robot.teleop_step(record_data=False)
dt_s = (time.time() - now)
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
if __name__ == "__main__":
teleoperate()

View File

@@ -1,320 +0,0 @@
import argparse
from pathlib import Path
import time
import warnings
import cv2
import numpy as np
from examples.real_robot_example.gym_real_world.robot import Robot
import signal
import sys
# import pyrealsense2 as rs
MAX_LEADER_GRIPPER_RAD = 0.7761942786701344
MAX_LEADER_GRIPPER_POS = 2567
MAX_FOLLOWER_GRIPPER_RAD = 1.6827769243105486
MAX_FOLLOWER_GRIPPER_POS = 3100
MIN_LEADER_GRIPPER_RAD = -0.12732040539450828
MIN_LEADER_GRIPPER_POS = 1984
MIN_FOLLOWER_GRIPPER_RAD = 0.6933593161243099
MIN_FOLLOWER_GRIPPER_POS = 2512
CAMERA_WIDTH = 640
CAMERA_HEIGHT = 480
def convert_gripper_range_from_leader_to_follower(leader_gripper_pos):
follower_gripper_pos = \
(leader_gripper_pos - MIN_LEADER_GRIPPER_POS) \
/ (MAX_LEADER_GRIPPER_POS - MIN_LEADER_GRIPPER_POS) \
* (MAX_FOLLOWER_GRIPPER_POS - MIN_FOLLOWER_GRIPPER_POS) \
+ MIN_FOLLOWER_GRIPPER_POS
return follower_gripper_pos
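# Example (sketch): the mapping above is linear between the two gripper ranges, so a leader position of
# MIN_LEADER_GRIPPER_POS (1984) maps to MIN_FOLLOWER_GRIPPER_POS (2512), MAX_LEADER_GRIPPER_POS (2567)
# maps to MAX_FOLLOWER_GRIPPER_POS (3100), and the leader midpoint (~2275.5) maps to ~2806.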
# alexander koch
# leader_port = "/dev/ttyACM1"
# follower_port = "/dev/ttyACM0"
def disable_torque():
leader_right_port = "/dev/ttyDXL_master_right"
leader_left_port = "/dev/ttyDXL_master_left"
follower_right_port = "/dev/ttyDXL_puppet_right"
follower_left_port = "/dev/ttyDXL_puppet_left"
# starts at 1
all_servo_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
leader_right = Robot(leader_right_port, servo_ids=all_servo_ids)
leader_left = Robot(leader_left_port, servo_ids=all_servo_ids)
follower_right = Robot(follower_right_port, servo_ids=all_servo_ids)
follower_left = Robot(follower_left_port, servo_ids=all_servo_ids)
leader_right._disable_torque()
leader_left._disable_torque()
follower_right._disable_torque()
follower_left._disable_torque()
def teleoperate():
leader_right_port = "/dev/ttyDXL_master_right"
follower_right_port = "/dev/ttyDXL_puppet_right"
leader_left_port = "/dev/ttyDXL_master_left"
follower_left_port = "/dev/ttyDXL_puppet_left"
# starts at 1
all_servo_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
leader_right = Robot(leader_right_port, servo_ids=all_servo_ids)
leader_left = Robot(leader_left_port, servo_ids=all_servo_ids)
follower_right = Robot(follower_right_port, servo_ids=all_servo_ids)
follower_left = Robot(follower_left_port, servo_ids=all_servo_ids)
follower_right._enable_torque()
follower_left._enable_torque()
while True:
now = time.time()
# Prepare to assign the positions of the leader to the follower
follower_right_pos = leader_right.read_position()
follower_left_pos = leader_left.read_position()
# Update the position of the follower gripper to account for the different minimum and maximum range
# position in range [0, 4096[ which corresponds to 4096 bins of 360 degrees
# for all our dynamixel servos
# gripper id=8 has a different range from leader to follower
follower_right_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_right_pos[-1])
follower_left_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_left_pos[-1])
# Assign
follower_right.set_goal_pos(follower_right_pos)
follower_left.set_goal_pos(follower_left_pos)
print(f"Time to send pos: {(time.time() - now) * 1000}")
def capture_frame(camera: cv2.VideoCapture, output_color="rgb"):
# OpenCV acquires frames in BGR format (blue, green, red)
ret, frame = camera.read()
if not ret:
raise OSError(f"Camera not found.")
if output_color == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return frame
def find_camera_ids(out_dir="outputs/find_camera_ids/2024_06_19_cv2_1344"):
save_images = True
max_index_search_range = 60
num_warmup_frames = 4
# Works well
codec = "yuyv"
fps = 30
width = 640
height = 480
# # Works well
# codec = "yuyv"
# fps = 60
# width = 640
# height = 480
# # Works well
# codec = "yuyv"
# fps = 90
# width = 640
# height = 480
# # Works well
# codec = "yuyv"
# fps = 30
# width = 1280
# height = 720
# Doesn't work well (timeout)
# codec = "mjpg"
# fps = 30
# width = 1280
# height = 720
out_dir += f"_{width}x{height}_{fps}_{codec}"
camera_ids = []
for camera_idx in range(max_index_search_range):
camera = cv2.VideoCapture(camera_idx)
is_open = camera.isOpened()
camera.release()
if is_open:
print(f"Camera found at index {camera_idx}")
camera_ids.append(camera_idx)
if len(camera_ids) == 0:
raise OSError("No camera has been found")
# Change camera settings
cameras = []
for camera_idx in camera_ids:
camera = cv2.VideoCapture(camera_idx)
camera.set(cv2.CAP_PROP_FPS, fps)
camera.set(cv2.CAP_PROP_FRAME_WIDTH, width)
camera.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
if codec == "mjpg":
camera.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))
actual_fps = camera.get(cv2.CAP_PROP_FPS)
actual_width = camera.get(cv2.CAP_PROP_FRAME_WIDTH)
actual_height = camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
if fps != actual_fps:
warnings.warn(f"{fps=} != {actual_fps=}", stacklevel=1)
if width != actual_width:
warnings.warn(f"{width=} != {actual_width=}", stacklevel=1)
if height != actual_height:
warnings.warn(f"{height=} != {actual_height=}", stacklevel=1)
cameras.append(camera)
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
print("Capturing a few frames to warmup")
for _ in range(num_warmup_frames):
for camera_idx, camera in zip(camera_ids, cameras):
print(f"Capturing camera {camera_idx}")
try:
frame = capture_frame(camera, output_color="bgr" if save_images else "rgb")
time.sleep(0.01)
except OSError as e:
print(e)
time.sleep(0.1)
print("Capturing frames")
try:
while True:
now = time.time()
for camera_idx, camera in zip(camera_ids, cameras):
try:
frame = capture_frame(camera, output_color="bgr" if save_images else "rgb")
except OSError as e:
print(e)
def write_shape(frame):
height, width = frame.shape[:2]
text = f'Width: {width} Height: {height}'
# Define the font, scale, color, and thickness
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
color = (255, 0, 0) # Blue in BGR
thickness = 2
position = (10, height - 10) # 10 pixels from the bottom-left corner
cv2.putText(frame, text, position, font, font_scale, color, thickness)
if save_images:
frame_path = out_dir / f"camera_{camera_idx:02}.png"
print(f"Write to {frame_path}")
write_shape(frame)
cv2.imwrite(str(frame_path), frame)
time.sleep(0.1)
dt_s = (time.time() - now)
dt_ms = dt_s * 1000
freq = 1 / dt_s
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
if save_images:
break
if cv2.waitKey(1) & 0xFF == ord("q"):
break
finally:
# Stop streaming
for camera in cameras:
camera.release()
return camera_ids
def record_data():
leader_right_port = "/dev/ttyDXL_master_right"
follower_right_port = "/dev/ttyDXL_puppet_right"
leader_left_port = "/dev/ttyDXL_master_left"
follower_left_port = "/dev/ttyDXL_puppet_left"
# starts at 1
all_servo_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
leader_right = Robot(leader_right_port, servo_ids=all_servo_ids)
leader_left = Robot(leader_left_port, servo_ids=all_servo_ids)
follower_right = Robot(follower_right_port, servo_ids=all_servo_ids)
follower_left = Robot(follower_left_port, servo_ids=all_servo_ids)
follower_right._enable_torque()
follower_left._enable_torque()
# To get the camera_ids, run: `find_camera_ids()`
camera_high = cv2.VideoCapture(10)
camera_low = cv2.VideoCapture(22)
camera_right_wrist = cv2.VideoCapture(16)
camera_left_wrist = cv2.VideoCapture(4)
if not camera_high.isOpened():
raise OSError("Camera high port can't be accessed.")
if not camera_low.isOpened():
raise OSError("Camera low port can't be accessed.")
if not camera_right_wrist.isOpened():
raise OSError("Camera right_wrist port can't be accessed.")
if not camera_left_wrist.isOpened():
raise OSError("Camera left_wrist port can't be accessed.")
while True:
now = time.time()
frame_high = capture_frame(camera_high)
frame_low = capture_frame(camera_low)
frame_right_wrist = capture_frame(camera_right_wrist)
frame_left_wrist = capture_frame(camera_left_wrist)
# cv2.imshow("high", frame_high)
# cv2.imshow("low", frame_low)
# cv2.imshow("right_wrist", frame_right_wrist)
# cv2.imshow("left_wrist", frame_left_wrist)
# Prepare to assign the positions of the leader to the follower
follower_right_pos = leader_right.read_position()
follower_left_pos = leader_left.read_position()
# Update the position of the follower gripper to account for the different minimum and maximum range
# position in range [0, 4096[ which corresponds to 4096 bins of 360 degrees
# for all our dynamixel servos
# gripper id=8 has a different range from leader to follower
follower_right_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_right_pos[-1])
follower_left_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_left_pos[-1])
# Assign
follower_right.set_goal_pos(follower_right_pos)
follower_left.set_goal_pos(follower_left_pos)
dt_s = (time.time() - now)
dt_ms = dt_s * 1000
freq = 1 / dt_s
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
if cv2.waitKey(1) & 0xFF == ord("q"):
break
camera_high.release()
camera_low.release()
camera_right_wrist.release()
camera_left_wrist.release()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, choices=["teleoperate", "disable_torque", "record_data", "find_camera_ids"], default="teleoperate")
args = parser.parse_args()
if args.mode == "teleoperate":
teleoperate()
elif args.mode == "disable_torque":
disable_torque()
elif args.mode == "record_data":
record_data()
elif args.mode == "find_camera_ids":
find_camera_ids()

View File

@@ -1,338 +0,0 @@
import argparse
from pathlib import Path
import time
import traceback
import cv2
import numpy as np
from examples.real_robot_example.gym_real_world.robot import Robot
import signal
import sys
import pyrealsense2 as rs
MAX_LEADER_GRIPPER_RAD = 0.7761942786701344
MAX_LEADER_GRIPPER_POS = 2567
MAX_FOLLOWER_GRIPPER_RAD = 1.6827769243105486
MAX_FOLLOWER_GRIPPER_POS = 3100
MIN_LEADER_GRIPPER_RAD = -0.12732040539450828
MIN_LEADER_GRIPPER_POS = 1984
MIN_FOLLOWER_GRIPPER_RAD = 0.6933593161243099
MIN_FOLLOWER_GRIPPER_POS = 2512
CAMERA_WIDTH = 640
CAMERA_HEIGHT = 480
def convert_gripper_range_from_leader_to_follower(leader_gripper_pos):
follower_gripper_pos = \
(leader_gripper_pos - MIN_LEADER_GRIPPER_POS) \
/ (MAX_LEADER_GRIPPER_POS - MIN_LEADER_GRIPPER_POS) \
* (MAX_FOLLOWER_GRIPPER_POS - MIN_FOLLOWER_GRIPPER_POS) \
+ MIN_FOLLOWER_GRIPPER_POS
return follower_gripper_pos
# alexander koch
# leader_port = "/dev/ttyACM1"
# follower_port = "/dev/ttyACM0"
def disable_torque():
leader_right_port = "/dev/ttyDXL_master_right"
leader_left_port = "/dev/ttyDXL_master_left"
follower_right_port = "/dev/ttyDXL_puppet_right"
follower_left_port = "/dev/ttyDXL_puppet_left"
# starts at 1
all_servo_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
leader_right = Robot(leader_right_port, servo_ids=all_servo_ids)
leader_left = Robot(leader_left_port, servo_ids=all_servo_ids)
follower_right = Robot(follower_right_port, servo_ids=all_servo_ids)
follower_left = Robot(follower_left_port, servo_ids=all_servo_ids)
leader_right._disable_torque()
leader_left._disable_torque()
follower_right._disable_torque()
follower_left._disable_torque()
def teleoperate():
leader_right_port = "/dev/ttyDXL_master_right"
follower_right_port = "/dev/ttyDXL_puppet_right"
leader_left_port = "/dev/ttyDXL_master_left"
follower_left_port = "/dev/ttyDXL_puppet_left"
# starts at 1
all_servo_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
leader_right = Robot(leader_right_port, servo_ids=all_servo_ids)
leader_left = Robot(leader_left_port, servo_ids=all_servo_ids)
follower_right = Robot(follower_right_port, servo_ids=all_servo_ids)
follower_left = Robot(follower_left_port, servo_ids=all_servo_ids)
follower_right._enable_torque()
follower_left._enable_torque()
while True:
now = time.time()
# Prepare to assign the positions of the leader to the follower
follower_right_pos = leader_right.read_position()
follower_left_pos = leader_left.read_position()
# Update the position of the follower gripper to account for the different minimum and maximum range
# position in range [0, 4096[ which corresponds to 4096 bins of 360 degrees
# for all our dynamixel servos
# gripper id=8 has a different range from leader to follower
follower_right_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_right_pos[-1])
follower_left_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_left_pos[-1])
# Assign
follower_right.set_goal_pos(follower_right_pos)
follower_left.set_goal_pos(follower_left_pos)
print(f"Time to send pos: {(time.time() - now) * 1000}")
def capture_frame(camera: cv2.VideoCapture, width=CAMERA_WIDTH, height=CAMERA_HEIGHT, output_color="rgb"):
# OpenCV acquires frames in BGR format (blue, green, red)
ret, frame = camera.read()
if not ret:
raise OSError(f"Camera not found.")
if output_color == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# # Define your crop coordinates (top left corner and bottom right corner)
# x1, y1 = 400, 0 # Example starting coordinates (top left of the crop rectangle)
# x2, y2 = 1600, 900 # Example ending coordinates (bottom right of the crop rectangle)
# # Crop the image
# image = image[y1:y2, x1:x2]
# Resize the image
frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
return frame
def write_shape(frame):
height, width = frame.shape[:2]
text = f'Width: {width} Height: {height}'
# Define the font, scale, color, and thickness
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
color = (255, 0, 0) # Blue in BGR
thickness = 2
position = (10, height - 10) # 10 pixels from the bottom-left corner
cv2.putText(frame, text, position, font, font_scale, color, thickness)
def find_camera_ids(out_dir="outputs/find_camera_ids/2024_06_19_1039"):
"""
Install: https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md
List cameras and make sure the firmware is up-to-date: https://dev.intelrealsense.com/docs/firmware-releases-d400
```bash
rs-fw-update -l
> Connected devices:
> 1) [USB] Intel RealSense D405 s/n 128422270109, update serial number: 133323070634, firmware version: 5.16.0.1
> 2) [USB] Intel RealSense D405 s/n 128422271609, update serial number: 130523070758, firmware version: 5.16.0.1
> 3) [USB] Intel RealSense D405 s/n 128422271614, update serial number: 133323070576, firmware version: 5.16.0.1
> 4) [USB] Intel RealSense D405 s/n 128422271393, update serial number: 133323070271, firmware version: 5.16.0.1
```
"""
save_images = False
# enable once if you hit the "Frame didn't arrive" exception
force_hardware_reset = False
# Works well!
# use_depth = False
# fps = 90
# width = 640
# height = 480
# # Works well!
# use_depth = True
# fps = 90
# width = 640
# height = 480
# # Doesn't work well, latency varies too much
# use_depth = True
# fps = 30
# width = 1280
# height = 720
# Works well
use_depth = False
fps = 30
width = 1280
height = 720
out_dir += f"_{width}x{height}_{fps}_depth_{use_depth}"
ctx = rs.context()
serials = []
cameras = []
for device in ctx.query_devices():
print(device)
if force_hardware_reset:
device.hardware_reset()
SERIAL_NUMBER_INDEX = 1
serial_number = device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX))
config = rs.config()
config.enable_device(serial_number)
config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, fps)
if use_depth:
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
pipeline = rs.pipeline()
pipeline.start(config)
serials.append(serial_number)
cameras.append(pipeline)
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
try:
while True:
now = time.time()
for serial, camera in zip(serials, cameras):
# Wait for a coherent pair of frames: depth and color
try:
frames = camera.wait_for_frames()
except RuntimeError as e:
if "Frame didn't arrive" in str(e):
print(f"{e}: Trying hardware_reset. If it still doesn't work, try `force_hardware_reset=True`.")
device.hardware_reset()
traceback.print_exc()
continue
# acquire color image
color_frame = frames.get_color_frame()
if not color_frame:
print("Empty color frame")
continue
# to numpy
image = np.asanyarray(color_frame.get_data())
if save_images:
image_path = out_dir / f"camera_{serial:02}.png"
print(f"Write to {image_path}")
write_shape(image)
cv2.imwrite(str(image_path), image)
if use_depth:
# acquire depth image
depth_frame = frames.get_depth_frame()
if not depth_frame:
print("Empty depth frame")
continue
# to numpy
depth = np.asanyarray(depth_frame.get_data())
if save_images:
# Apply colormap on depth image (image must be converted to 8-bit per pixel first)
depth_image = cv2.applyColorMap(cv2.convertScaleAbs(depth, alpha=0.03), cv2.COLORMAP_JET)
depth_image_path = out_dir / f"camera_{serial:02}_depth.png"
print(f"Write to {depth_image_path}")
write_shape(depth_image)
cv2.imwrite(str(depth_image_path), depth_image)
dt_s = (time.time() - now)
dt_ms = dt_s * 1000
freq = 1 / dt_s
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
if save_images:
break
if cv2.waitKey(1) & 0xFF == ord("q"):
break
finally:
# Stop streaming
for camera in cameras:
camera.stop()
return serials
def record_data():
leader_right_port = "/dev/ttyDXL_master_right"
follower_right_port = "/dev/ttyDXL_puppet_right"
leader_left_port = "/dev/ttyDXL_master_left"
follower_left_port = "/dev/ttyDXL_puppet_left"
# starts at 1
all_servo_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]
leader_right = Robot(leader_right_port, servo_ids=all_servo_ids)
leader_left = Robot(leader_left_port, servo_ids=all_servo_ids)
follower_right = Robot(follower_right_port, servo_ids=all_servo_ids)
follower_left = Robot(follower_left_port, servo_ids=all_servo_ids)
follower_right._enable_torque()
follower_left._enable_torque()
# To get the camera_ids, run: `find_camera_ids()`
camera_high = cv2.VideoCapture(10)
camera_low = cv2.VideoCapture(22)
camera_right_wrist = cv2.VideoCapture(16)
camera_left_wrist = cv2.VideoCapture(4)
if not camera_high.isOpened():
raise OSError("Camera high port can't be accessed.")
if not camera_low.isOpened():
raise OSError("Camera low port can't be accessed.")
if not camera_right_wrist.isOpened():
raise OSError("Camera right_wrist port can't be accessed.")
if not camera_left_wrist.isOpened():
raise OSError("Camera left_wrist port can't be accessed.")
while True:
now = time.time()
frame_high = capture_frame(camera_high)
frame_low = capture_frame(camera_low)
frame_right_wrist = capture_frame(camera_right_wrist)
frame_left_wrist = capture_frame(camera_left_wrist)
# cv2.imshow("high", frame_high)
# cv2.imshow("low", frame_low)
# cv2.imshow("right_wrist", frame_right_wrist)
# cv2.imshow("left_wrist", frame_left_wrist)
# Prepare to assign the positions of the leader to the follower
follower_right_pos = leader_right.read_position()
follower_left_pos = leader_left.read_position()
# Update the position of the follower gripper to account for the different minimum and maximum range
# position in range [0, 4096[ which corresponds to 4096 bins of 360 degrees
# for all our dynamixel servos
# gripper id=8 has a different range from leader to follower
follower_right_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_right_pos[-1])
follower_left_pos[-1] = convert_gripper_range_from_leader_to_follower(follower_left_pos[-1])
# Assign
follower_right.set_goal_pos(follower_right_pos)
follower_left.set_goal_pos(follower_left_pos)
print(f"Time to send pos: {(time.time() - now) * 1000}")
if cv2.waitKey(1) & 0xFF == ord("q"):
break
camera_high.release()
camera_low.release()
camera_right_wrist.release()
camera_left_wrist.release()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, choices=["teleoperate", "disable_torque", "record_data", "find_camera_ids"], default="teleoperate")
args = parser.parse_args()
if args.mode == "teleoperate":
teleoperate()
elif args.mode == "disable_torque":
disable_torque()
elif args.mode == "record_data":
record_data()
elif args.mode == "find_camera_ids":
find_camera_ids()

View File

@@ -1,10 +0,0 @@
import pyrealsense2 as rs
pipe = rs.pipeline()
profile = pipe.start()
try:
for i in range(0, 100):
frames = pipe.wait_for_frames()
for f in frames:
print(f.profile)
finally:
pipe.stop()

View File

@@ -15,30 +15,25 @@
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
import datasets
import hydra
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.cuda.amp import GradScaler
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from omegaconf import DictConfig
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
@@ -74,6 +69,7 @@ def make_optimizer_and_scheduler(cfg, policy):
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
@@ -91,40 +87,21 @@ def make_optimizer_and_scheduler(cfg, policy):
return optimizer, lr_scheduler
def update_policy(
policy,
batch,
optimizer,
grad_clip_norm,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
):
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"""Returns a dictionary of items for logging."""
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
start_time = time.time()
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.step()
optimizer.zero_grad()
if lr_scheduler is not None:
@@ -138,13 +115,31 @@ def update_policy(
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.perf_counter() - start_time,
"update_s": time.time() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
return info
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
@@ -216,6 +211,103 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
logger.log_dict(info, step, mode="eval")
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
"""
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from the online dataset (on average).
Parameters:
- n_off (int): Number of offline samples, each with a sampling weight of 1.
- n_on (int): Number of online samples.
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
The total weight of offline samples is n_off * 1.0.
The total weight of online samples is n_on * w.
The total combined weight of all samples is n_off + n_on * w.
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
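Worked example (hypothetical numbers): with n_off = 900, n_on = 100 and pc_on = 0.5,
w = -(900 * 0.5) / (100 * (0.5 - 1)) = 9.0, so each online frame is weighted 9x and online
frames make up ~50% of sampled batches on average.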
"""
assert 0.0 <= pc_on <= 1.0
return -(n_off * pc_on) / (n_on * (pc_on - 1))
def add_episodes_inplace(
online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
new episodes from hf_dataset into the online_dataset, updating the concatenated
dataset's structure and adjusting the sampling strategy based on the specified
percentage of online samples.
Parameters:
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes.
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated with dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
# sanity check
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
assert first_index == 0, f"{first_index=} is not 0"
assert first_index == episode_data_index["from"][first_episode_idx].item()
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.hf_dataset = hf_dataset
online_dataset.episode_data_index = episode_data_index
else:
# get the starting indices of the new episodes and frames to be added
start_episode_idx = last_episode_idx + 1
start_index = last_index + 1
def shift_indices(episode_index, index):
# note: we don't shift "frame_index" since it represents the index of the frame in the episode it belongs to
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
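The index-shifting step above can be reproduced on a toy example (a sketch using only the public `datasets` API; the column values are invented): rollout frames arrive indexed from 0 and are re-based onto the end of the existing online buffer before concatenation.
from datasets import Dataset

new_frames = Dataset.from_dict(
    {"episode_index": [0, 0], "index": [0, 1], "frame_index": [0, 1]}
)
start_episode_idx, start_index = 3, 120  # last_episode_idx + 1, last_index + 1

def shift_indices(episode_index, index):
    # "frame_index" stays untouched: it is relative to its own episode
    return {"episode_index": episode_index + start_episode_idx, "index": index + start_index}

shifted = new_frames.map(shift_indices, input_columns=["episode_index", "index"])
print(shifted["episode_index"], shifted["index"])  # [3, 3] [120, 121]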
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
@@ -224,96 +316,35 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
init_logging()
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
# to check for any differences between the provided config and the checkpoint's config.
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore differences in the `resume` parameter.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
# Log a warning about differences between the checkpoint configuration and the provided
# configuration.
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists."
)
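# (A sketch, not taken from the diff.) The DeepDiff comparison above treats both configs as plain
# containers; changed keys are reported under "values_changed" with root['<key>'] paths, which is
# why the `resume` entry can be deleted by its path before warning about any remaining differences:
#
#     from deepdiff import DeepDiff
#     diff = DeepDiff({"resume": False, "seed": 1}, {"resume": True, "seed": 1})
#     # {'values_changed': {"root['resume']": {'new_value': True, 'old_value': False}}}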
# log metrics to terminal and wandb
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
if cfg.training.online_steps > 0:
raise NotImplementedError("Online training is not implemented yet.")
set_global_seed(cfg.seed)
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
logging.warning("eval.batch_size > 1 not supported for online training steps")
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, there is no need to create an environment, as evaluations are done outside train.py,
# using eval.py instead, with the gym_dora environment and dora-rs.
if cfg.training.eval_freq > 0:
logging.info("make_env")
eval_env = make_env(cfg)
logging.info("make_env")
eval_env = make_env(cfg)
logging.info("make_policy")
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step = logger.load_last_training_state(optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -325,31 +356,27 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step):
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
video_dir=Path(out_dir) / "eval",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
if cfg.training.save_model and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_checkpont(
step,
logger.save_model(
policy,
optimizer,
lr_scheduler,
identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
),
@@ -357,48 +384,32 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("Resume training")
# create dataloader for offline training
if cfg.training.get("drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
offline_dataset.episode_data_index,
drop_n_last_frames=cfg.training.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
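# (A sketch, not from the diff.) The effect of drop_n_last_frames can be pictured on one episode
# spanning dataset indices 0..9 (episode_data_index: from=0, to=10, with "to" exclusive):
#
#     kept = list(range(0, 10 - 2))  # with drop_n_last_frames=2, indices 0..7 presumably stay sampleable
#
# i.e. the sampler is assumed to drop the last two frames of each episode from the candidate indices.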
dataloader = torch.utils.data.DataLoader(
offline_dataset,
num_workers=cfg.training.num_workers,
num_workers=4,
batch_size=cfg.training.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
shuffle=True,
pin_memory=cfg.device != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
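# (A sketch, not necessarily the repository's exact helper.) `cycle` above is the usual
# infinite-dataloader wrapper, along the lines of:
#
#     def cycle(iterable):
#         while True:
#             for item in iterable:
#                 yield item
#
# re-creating the iterator each time it is exhausted so the sampler reshuffles between passes.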
policy.train()
for _ in range(step, cfg.training.offline_steps):
if step == 0:
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.training.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(device, non_blocking=True)
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(
policy,
batch,
optimizer,
cfg.training.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
@@ -406,27 +417,78 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
if cfg.training.eval_freq > 0:
eval_env.close()
logging.info("End of training")
# create an env dedicated to online episodes collection from policy rollout
online_training_env = make_env(cfg, n_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=cfg.device != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
online_step = 0
is_offline = False
for env_step in range(cfg.training.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
policy.eval()
with torch.no_grad():
eval_info = eval_policy(
online_training_env,
policy,
n_episodes=1,
return_episode_data=True,
start_seed=cfg.training.online_env_seed,
enable_progbar=True,
)
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
add_episodes_inplace(
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.training.online_sampling_ratio,
)
policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
online_step += 1
eval_env.close()
online_training_env.close()
logging.info("End of training")
if __name__ == "__main__":

View File

@@ -224,8 +224,7 @@ def main():
help=(
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. "
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
)
parser.add_argument(
@@ -246,8 +245,8 @@ def main():
default=0,
help=(
"Save a .rrd file in the directory provided by `--output-dir`. "
"It also deactivates the spawning of a viewer. "
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
"It also deactivates the spawning of a viewer. ",
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
),
)
parser.add_argument(

poetry.lock generated
View File

@@ -444,63 +444,63 @@ files = [
[[package]]
name = "coverage"
version = "7.5.3"
version = "7.5.1"
description = "Code coverage measurement for Python"
optional = true
python-versions = ">=3.8"
files = [
{file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"},
{file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"},
{file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"},
{file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"},
{file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"},
{file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"},
{file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"},
{file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"},
{file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
{file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
{file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
{file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
{file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"},
{file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"},
{file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"},
{file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"},
{file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"},
{file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"},
{file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"},
{file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"},
{file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
{file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
]
[package.dependencies]
@@ -595,24 +595,6 @@ files = [
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
]
[[package]]
name = "deepdiff"
version = "7.0.1"
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
optional = false
python-versions = ">=3.8"
files = [
{file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"},
{file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"},
]
[package.dependencies]
ordered-set = ">=4.1.0,<4.2.0"
[package.extras]
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
optimize = ["orjson"]
[[package]]
name = "diffusers"
version = "0.27.2"
@@ -805,19 +787,6 @@ files = [
[package.dependencies]
pyarrow = "*"
[[package]]
name = "dynamixel-sdk"
version = "3.7.31"
description = "Dynamixel SDK 3. python package"
optional = true
python-versions = "*"
files = [
{file = "dynamixel_sdk-3.7.31-py3-none-any.whl", hash = "sha256:74e8c112ca6b0b869b196dd8c6a44ffd5dd5c1a3cb9fe2030e9933922406b466"},
]
[package.dependencies]
pyserial = "*"
[[package]]
name = "einops"
version = "0.8.0"
@@ -1106,7 +1075,7 @@ description = ""
optional = true
python-versions = "^3.10"
files = []
develop = false
develop = true
[package.dependencies]
dora-rs = ">=0.3.4"
@@ -1114,11 +1083,8 @@ gymnasium = ">=0.29.1"
pyarrow = ">=12.0.0"
[package.source]
type = "git"
url = "https://github.com/dora-rs/dora-lerobot.git"
reference = "HEAD"
resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79"
subdirectory = "gym_dora"
type = "directory"
url = "gym_dora"
[[package]]
name = "gym-pusht"
@@ -1323,13 +1289,13 @@ files = [
[[package]]
name = "huggingface-hub"
version = "0.23.2"
version = "0.23.0"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
{file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
{file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
{file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
]
[package.dependencies]
@@ -2115,15 +2081,18 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
name = "nodeenv"
version = "1.9.0"
version = "1.8.0"
description = "Node.js virtual environment builder"
optional = true
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
files = [
{file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"},
{file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"},
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
]
[package.dependencies]
setuptools = "*"
[[package]]
name = "numba"
version = "0.59.1"
@@ -2365,14 +2334,13 @@ files = [
[[package]]
name = "nvidia-nvjitlink-cu12"
version = "12.5.40"
version = "12.4.127"
description = "Nvidia JIT LTO Library"
optional = false
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"},
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"},
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
]
[[package]]
@@ -2425,20 +2393,6 @@ numpy = [
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
]
[[package]]
name = "ordered-set"
version = "4.1.0"
description = "An OrderedSet is a custom MutableSet that remembers its order, so that every"
optional = false
python-versions = ">=3.7"
files = [
{file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"},
{file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"},
]
[package.extras]
dev = ["black", "mypy", "pytest"]
[[package]]
name = "packaging"
version = "24.0"
@@ -3015,20 +2969,6 @@ files = [
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pyserial"
version = "3.5"
description = "Python Serial Port Extension"
optional = true
python-versions = "*"
files = [
{file = "pyserial-3.5-py2.py3-none-any.whl", hash = "sha256:c4451db6ba391ca6ca299fb3ec7bae67a5c55dde170964c7a14ceefec02f2cf0"},
{file = "pyserial-3.5.tar.gz", hash = "sha256:3c77e014170dfffbd816e6ffc205e9842efb10be9f58ec16d3e8675b4925cddb"},
]
[package.extras]
cp2110 = ["hidapi"]
[[package]]
name = "pysocks"
version = "1.7.1"
@@ -3256,13 +3196,13 @@ files = [
[[package]]
name = "requests"
version = "2.32.3"
version = "2.32.1"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
files = [
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
{file = "requests-2.32.1-py3-none-any.whl", hash = "sha256:21ac9465cdf8c1650fe1ecde8a71669a93d4e6f147550483a2967d08396a56a5"},
{file = "requests-2.32.1.tar.gz", hash = "sha256:eb97e87e64c79e64e5b8ac75cee9dd1f97f49e289b083ee6be96268930725685"},
]
[package.dependencies]
@@ -3278,16 +3218,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rerun-sdk"
version = "0.16.1"
version = "0.16.0"
description = "The Rerun Logging SDK"
optional = false
python-versions = "<3.13,>=3.8"
files = [
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"},
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"},
{file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"},
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"},
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"},
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"},
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"},
{file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"},
]
[package.dependencies]
@@ -3467,36 +3407,36 @@ test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-
[[package]]
name = "scipy"
version = "1.13.1"
version = "1.13.0"
description = "Fundamental algorithms for scientific computing in Python"
optional = true
python-versions = ">=3.9"
files = [
{file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"},
{file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"},
{file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"},
{file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"},
{file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"},
{file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"},
{file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"},
{file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"},
{file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"},
{file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"},
{file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"},
{file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"},
{file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"},
{file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"},
{file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"},
{file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"},
{file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"},
{file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"},
{file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"},
{file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"},
{file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"},
{file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"},
{file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"},
{file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"},
{file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"},
{file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
{file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
{file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
{file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
{file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
{file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
{file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
{file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
{file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
{file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
{file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
{file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
{file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
{file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
{file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
{file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
{file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
{file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
{file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
{file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
{file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
{file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
{file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
{file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
{file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
]
[package.dependencies]
@@ -3509,13 +3449,13 @@ test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "po
[[package]]
name = "sentry-sdk"
version = "2.3.1"
version = "2.2.1"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = ">=3.6"
files = [
{file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"},
{file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"},
{file = "sentry_sdk-2.2.1-py2.py3-none-any.whl", hash = "sha256:7d617a1b30e80c41f3b542347651fcf90bb0a36f3a398be58b4f06b79c8d85bc"},
{file = "sentry_sdk-2.2.1.tar.gz", hash = "sha256:8aa2ec825724d8d9d645cab68e6034928b1a6a148503af3e361db3fa6401183f"},
]
[package.dependencies]
@@ -3764,17 +3704,17 @@ files = [
[[package]]
name = "sympy"
version = "1.12.1"
version = "1.12"
description = "Computer algebra system (CAS) in Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
{file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
{file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
]
[package.dependencies]
mpmath = ">=1.1.0,<1.4.0"
mpmath = ">=0.19"
[[package]]
name = "tbb"
@@ -3805,13 +3745,13 @@ tests = ["pytest", "pytest-cov"]
[[package]]
name = "tifffile"
version = "2024.5.22"
version = "2024.5.10"
description = "Read and write TIFF files"
optional = true
python-versions = ">=3.9"
files = [
{file = "tifffile-2024.5.22-py3-none-any.whl", hash = "sha256:e281781c15d7d197d7e12749849c965651413aa905f97a48b0f84bd90a3b4c6f"},
{file = "tifffile-2024.5.22.tar.gz", hash = "sha256:3a105801d1b86d55692a98812a170c39d3f0447aeacb1d94635d38077cb328c4"},
{file = "tifffile-2024.5.10-py3-none-any.whl", hash = "sha256:4154f091aa24d4e75bfad9ab2d5424a68c70e67b8220188066dc61946d4551bd"},
{file = "tifffile-2024.5.10.tar.gz", hash = "sha256:aa1e1b12be952ab20717d6848bd6d4a5ee88d2aa319f1152bff4354ad728ec86"},
]
[package.dependencies]
@@ -3967,13 +3907,13 @@ tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
[[package]]
name = "typing-extensions"
version = "4.12.0"
version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"},
{file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"},
{file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
{file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]
[[package]]
@@ -4288,13 +4228,13 @@ multidict = ">=4.0"
[[package]]
name = "zarr"
version = "2.18.2"
version = "2.18.1"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"},
{file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"},
{file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"},
{file = "zarr-2.18.1.tar.gz", hash = "sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"},
]
[package.dependencies]
@@ -4309,13 +4249,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
[[package]]
name = "zipp"
version = "3.19.0"
version = "3.18.2"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
{file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
{file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
{file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
{file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
]
[package.extras]
@@ -4334,4 +4274,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "390d4f67713be47f0d44446d8198e577d28040deddea431937a9eb0fdf40adbd"
content-hash = "ea4e8207316a8ec8a4b95d6a89cf488c8733a8e7ab43e5f669c889ee87f3bef3"

View File

@@ -41,12 +41,12 @@ numba = ">=0.59.0"
torch = "^2.2.1"
opencv-python = ">=4.9.0"
diffusers = "^0.27.2"
torchvision = ">=0.17.1"
torchvision = ">=0.18.0"
h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
gymnasium = ">=0.29.1"
cmake = ">=3.29.0.1"
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
gym-dora = { path = "gym_dora", optional = true, develop = true}
gym-pusht = { version = ">=0.1.3", optional = true}
gym-xarm = { version = ">=0.1.1", optional = true}
gym-aloha = { version = ">=0.1.1", optional = true}
@@ -59,9 +59,6 @@ imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1"
dynamixel-sdk = {version = "^3.7.31", optional = true}
[tool.poetry.extras]

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
size 5104

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
size 31688

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
size 34928

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
size 5104

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
size 30808

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
size 33608

View File

@@ -75,16 +75,15 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
dataset.delta_timestamps = None
batch = next(iter(dataloader))
obs = {}
for k in batch:
if k.startswith("observation"):
obs[k] = batch[k]
if "n_action_steps" in cfg.policy:
actions_queue = cfg.policy.n_action_steps
else:
actions_queue = cfg.policy.n_action_repeats
obs = {
k: batch[k]
for k in batch
if k in ["observation.image", "observation.images.top", "observation.state"]
}
actions_queue = (
cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
)
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
return output_dict, grad_stats, param_stats, actions
@@ -115,8 +114,6 @@ if __name__ == "__main__":
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@@ -16,7 +16,6 @@
import json
import logging
from copy import deepcopy
from itertools import chain
from pathlib import Path
import einops
@@ -26,34 +25,26 @@ from datasets import Dataset
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
)
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
compute_stats,
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.utils import (
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
lerobot.env_dataset_policy_triplets
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
)
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
def test_factory(env_name, repo_id, policy_name):
"""
Tests that:
- we can create a dataset with the factory.
- for a commonly used set of data keys, the data dimensions are correct.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
@@ -114,39 +105,6 @@ def test_factory(env_name, repo_id, policy_name):
assert key in item, f"{key}"
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
def test_multilerobotdataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
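For reference, a minimal self-contained sketch of the concatenation bookkeeping the deleted test above relies on: mapping a flat MultiLeRobotDataset index back to the sub-dataset it came from. The lengths and the bisect-based lookup are illustrative assumptions, not code from the repo.

import bisect

# Illustrative sub-dataset lengths, concatenated in repo_id order.
sub_dataset_lengths = [500, 400, 300]
boundaries = []
total = 0
for n in sub_dataset_lengths:
    total += n
    boundaries.append(total)  # [500, 900, 1200]

def sub_dataset_of(flat_index: int) -> int:
    # Index of the sub-dataset a flat MultiLeRobotDataset index falls into.
    return bisect.bisect_right(boundaries, flat_index)

assert sub_dataset_of(0) == 0 and sub_dataset_of(499) == 0
assert sub_dataset_of(500) == 1 and sub_dataset_of(1199) == 2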
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
@@ -357,31 +315,3 @@ def test_backward_compatibility(repo_id):
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
def test_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):
data_a = torch.rand(30, dtype=torch.float32)
data_b = torch.rand(20, dtype=torch.float32)
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
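For context, the allclose checks above amount to count-weighted aggregation of per-dataset statistics. A minimal self-contained sketch of that arithmetic (an illustration, not the aggregate_stats implementation itself):

import torch

# Illustrative per-dataset chunks of the same data key.
chunks = [torch.rand(10), torch.rand(10), torch.rand(10)]
counts = torch.tensor([len(c) for c in chunks], dtype=torch.float32)
means = torch.stack([c.mean() for c in chunks])
vars_ = torch.stack([c.var(correction=0) for c in chunks])

# Count-weighted aggregation is equivalent to stats over the concatenated data.
agg_mean = (means * counts).sum() / counts.sum()
agg_var = ((vars_ + means**2) * counts).sum() / counts.sum() - agg_mean**2

full = torch.cat(chunks)
assert torch.allclose(agg_mean, full.mean())
assert torch.allclose(agg_var.sqrt(), full.std(correction=0))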

View File

@@ -29,8 +29,8 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text
def _run_script(path, args=None):
subprocess.run([sys.executable, path] + args if args is not None else [], check=True)
def _run_script(path):
subprocess.run([sys.executable, path], check=True)
def _read_file(path):
@@ -45,11 +45,11 @@ def test_example_1():
@require_package("gym_pusht")
def test_examples_basic2_basic3_advanced1():
def test_examples_2_through_4():
"""
Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs.
Calculate the validation loss with advanced example 1, check the outputs.
Calculate the validation loss with example 4, check the outputs.
"""
### Test example 3
@@ -97,7 +97,7 @@ def test_examples_basic2_basic3_advanced1():
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()
## Test example 4
file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py")
file_contents = _read_file("examples/4_calculate_validation_loss.py")
# Run on a single example from the last episode, use CPU, and use the local model.
file_contents = _find_and_replace(
@@ -126,22 +126,3 @@ def test_examples_basic2_basic3_advanced1():
# Restore stdout to its original state
sys.stdout = sys.__stdout__
assert "Average loss on validation set" in printed_output
def test_real_world_recording():
path = "examples/real_robot_example/record_training_data.py"
_run_script(
path,
[
"--data_dir",
"outputs/examples",
"--repo-id",
"real_world_debug",
"--num-episodes",
"2",
"--num-frames",
"10",
"--mock-robot",
],
)
assert Path("outputs/examples/real_world_debug/video/episode_0.mp4").exists()

View File

@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.scripts.save_policy_to_safetensor import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -72,8 +72,6 @@ def test_get_policy_and_config_classes(policy_name: str):
),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
],
)
@require_env
@@ -86,9 +84,6 @@ def test_policy(env_name, policy_name, extra_overrides):
- Updating the policy.
- Using the policy to select actions at inference time.
- Test that the action can be applied to the policy.
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
@@ -140,7 +135,7 @@ def test_policy(env_name, policy_name, extra_overrides):
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
num_workers=4,
batch_size=2,
shuffle=True,
pin_memory=DEVICE != "cpu",
@@ -296,8 +291,6 @@ def test_normalize(insert_temporal_dim):
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't

View File

@@ -1,90 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
def test_drop_n_first_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert list(sampler) == [1, 4, 5]
def test_drop_n_last_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert list(sampler) == [0, 3, 4]
def test_episode_indices_to_use():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert list(sampler) == [0, 1, 3, 4, 5]
def test_shuffle():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}
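For reference, the expected indices asserted throughout this deleted file follow from a simple per-episode filtering rule. A minimal sketch of that rule with the same toy episode boundaries, using plain lists for illustration (not the EpisodeAwareSampler implementation):

# episode_index [0, 0, 1, 2, 2, 2] gives these per-episode frame ranges.
episode_data_index = {"from": [0, 2, 3], "to": [2, 3, 6]}
drop_n_first_frames, drop_n_last_frames = 1, 0

indices = []
for start, end in zip(episode_data_index["from"], episode_data_index["to"]):
    indices.extend(range(start + drop_n_first_frames, end - drop_n_last_frames))

assert indices == [1, 4, 5]  # matches test_drop_n_first_frames above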

View File

@@ -11,24 +11,22 @@ from lerobot.common.datasets.utils import (
hf_transform_to_torch,
reset_episode_index,
)
from lerobot.common.utils.utils import (
get_global_random_state,
seeded_context,
set_global_random_state,
set_global_seed,
from lerobot.common.utils.utils import seeded_context, set_global_seed
@pytest.mark.parametrize(
"rand_fn",
(
[
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
+ [lambda: torch.rand(1, device="cuda")]
if torch.cuda.is_available()
else []
),
)
# Random generation functions for testing the seeding and random state get/set.
rand_fns = [
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
if torch.cuda.is_available():
rand_fns.append(lambda: torch.rand(1, device="cuda"))
@pytest.mark.parametrize("rand_fn", rand_fns)
def test_seeding(rand_fn: Callable[[], int]):
set_global_seed(0)
a = rand_fn()
@@ -48,15 +46,6 @@ def test_seeding(rand_fn: Callable[[], int]):
assert c_ == c
def test_get_set_random_state():
"""Check that getting the random state, then setting it results in the same random number generation."""
random_state_dict = get_global_random_state()
rand_numbers = [rand_fn() for rand_fn in rand_fns]
set_global_random_state(random_state_dict)
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
assert rand_numbers_ == rand_numbers
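A self-contained sketch of the snapshot-and-restore pattern the deleted test exercised, assuming the helpers cover the python, numpy and torch generators (the CUDA generator is omitted here):

import random

import numpy as np
import torch

# Snapshot the global generator states.
state = {
    "random": random.getstate(),
    "numpy": np.random.get_state(),
    "torch": torch.random.get_rng_state(),
}
before = [random.random(), np.random.random(), torch.rand(1).item()]

# Restore the snapshot and draw again: the values must match exactly.
random.setstate(state["random"])
np.random.set_state(state["numpy"])
torch.random.set_rng_state(state["torch"])
after = [random.random(), np.random.random(), torch.rand(1).item()]

assert after == before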
def test_calculate_episode_data_index():
dataset = Dataset.from_dict(
{