diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index a9b733d4b..4063e395f 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,11 +1,15 @@
-# What does this PR do?
+## What this does
+Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
Examples:
-- Fixes # (issue)
-- Adds new dataset
-- Optimizes something
+| Title | Label |
+|----------------------|-----------------|
+| Fixes #[issue]       | (🐛 Bug)         |
+| Adds new dataset     | (🗃️ Dataset)      |
+| Optimizes something  | (⚡️ Performance) |
-## How was it tested?
+## How it was tested
+Explain/show how you tested your changes.
Examples:
- Added `test_something` in `tests/test_stuff.py`.
@@ -13,6 +17,7 @@ Examples:
- Optimized `some_function`, it now runs X times faster than previously.
## How to checkout & try? (for the reviewer)
+Provide a simple way for the reviewer to try out your changes.
Examples:
```bash
@@ -22,11 +27,8 @@ DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
python lerobot/scripts/train.py --some.option=true
```
-## Before submitting
-Please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
-
-
-## Who can review?
-
-Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
+## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
+**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
+
+**Note**: Before submitting this PR, please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f86e4f8e8..45feabdc8 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -57,6 +57,38 @@ jobs:
&& rm -rf tests/outputs outputs
+ pytest-minimal:
+ name: Pytest (minimal install)
+ runs-on: ubuntu-latest
+ env:
+ DATA_DIR: tests/data
+ MUJOCO_GL: egl
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Install poetry
+ run: |
+ pipx install poetry && poetry config virtualenvs.in-project true
+ echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
+
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+
+ - name: Install poetry dependencies
+ run: |
+ poetry install --extras "test"
+
+ - name: Test with pytest
+ run: |
+ pytest tests -v --cov=./lerobot --durations=0 \
+ -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
+ -W ignore::UserWarning:torch.utils.data.dataloader:558 \
+ -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
+ && rm -rf tests/outputs outputs
+
+
end-to-end:
name: End-to-end
runs-on: ubuntu-latest
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 93b11ca9d..913fcb5db 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,7 +18,7 @@ repos:
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.4.2
+ rev: v0.4.3
hooks:
- id: ruff
args: [--fix]
diff --git a/Makefile b/Makefile
index 07aa4e977..a0163f94d 100644
--- a/Makefile
+++ b/Makefile
@@ -22,9 +22,8 @@ test-end-to-end:
${MAKE} test-act-ete-eval
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
- # TODO(rcadene, alexander-soare): enable end-to-end tests for tdmpc
- # ${MAKE} test-tdmpc-ete-train
- # ${MAKE} test-tdmpc-ete-eval
+ ${MAKE} test-tdmpc-ete-train
+ ${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval
test-act-ete-train:
@@ -80,7 +79,7 @@ test-tdmpc-ete-train:
policy=tdmpc \
env=xarm \
env.task=XarmLift-v0 \
- dataset_repo_id=lerobot/xarm_lift_medium_replay \
+ dataset_repo_id=lerobot/xarm_lift_medium \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=2 \
diff --git a/README.md b/README.md
index ff4aa4980..b1c441cb5 100644
--- a/README.md
+++ b/README.md
@@ -29,15 +29,15 @@
---
-🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
+🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
-🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
+🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulation environments to get started without assembling a robot. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
-🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
+🤗 LeRobot hosts pretrained models and datasets on this Hugging Face community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
-#### Examples of pretrained models and environments
+#### Examples of pretrained models on simulation environments
@@ -54,10 +54,11 @@
### Acknowledgment
-- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
-- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
-- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
-- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
+- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
+- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
+- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
+- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
+
## Installation
@@ -86,15 +87,18 @@ For instance, to install π€ LeRobot with aloha and pusht, use:
pip install ".[aloha, pusht]"
```
-To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
+To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
```bash
wandb login
```
+(note: you will also need to enable WandB in the configuration. See below.)
+
## Walkthrough
```
.
+├── examples # contains demonstration examples, start here to learn about LeRobot
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
@@ -103,69 +107,84 @@ wandb login
| ├── common # contains classes and utilities
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
| | ├── envs # various sim environments: aloha, pusht, xarm
-| | └── policies # various policies: act, diffusion, tdmpc
-| └── scripts # contains functions to execute via command line
-| ├── visualize_dataset.py # load a dataset and render its demonstrations
-| ├── eval.py # load policy and evaluate it on an environment
-| └── train.py # train a policy via imitation learning and/or reinforcement learning
+| | ├── policies # various policies: act, diffusion, tdmpc
+| | └── utils # various utilities
+| └── scripts # contains functions to execute via command line
+| ├── eval.py # load policy and evaluate it on an environment
+| ├── train.py # train a policy via imitation learning and/or reinforcement learning
+| ├── push_dataset_to_hub.py # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
+| └── visualize_dataset.py # load a dataset and render its demonstrations
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
-├── .github
-| └── workflows
-| └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
└── tests # contains pytest utilities for continuous integration
-
```
### Visualize datasets
-Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
+Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class, which automatically downloads data from the Hugging Face hub.
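
For a quick look before opening the full example, here is a minimal sketch of that usage; it only relies on the `LeRobotDataset` class and the attributes shown in `examples/1_load_lerobot_dataset.py` (also updated later in this diff):

```python
# Minimal sketch following examples/1_load_lerobot_dataset.py.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")  # downloads from the Hugging Face hub if needed
print(dataset.num_episodes)  # number of episodes
print(dataset.num_samples)   # number of frames
print(dataset.fps)           # frames per second used during data collection
print(dataset.camera_keys)   # keys to access images from cameras
```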
-Or you can achieve the same result by executing our script from the command line:
+You can also locally visualize episodes from a dataset by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
-env=pusht \
-hydra.run.dir=outputs/visualize_dataset/example
-# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
+ --repo-id lerobot/pusht \
+ --episode-index 0
```
+It will open `rerun.io` and display the camera streams, robot states and actions, like this:
+
+https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
+
+
+Our script can also visualize datasets stored on a remote server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.
+
### Evaluate a pretrained policy
-Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
+Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from the Hugging Face hub and run an evaluation on its corresponding environment.
-Or you can achieve the same result by executing our script from the command line:
+We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
```bash
python lerobot/scripts/eval.py \
--p lerobot/diffusion_pusht \
-eval_episodes=10 \
-hydra.run.dir=outputs/eval/example_hub
+ -p lerobot/diffusion_pusht \
+ eval.n_episodes=10 \
+ eval.batch_size=10
```
-After training your own policy, you can also re-evaluate the checkpoints with:
-
+Note: After training your own policy, you can re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py \
--p PATH/TO/TRAIN/OUTPUT/FOLDER \
-eval_episodes=10 \
-hydra.run.dir=outputs/eval/example_dir
+ -p PATH/TO/TRAIN/OUTPUT/FOLDER
```
See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
-Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
+Check out [example 3](./examples/3_train_policy.py) that illustrates how to start training a model.
-In general, you can use our training script to easily train any policy on any environment:
+In general, you can use our training script to easily train any policy. Here is an example of training the ACT policy on trajectories collected by humans on the Aloha simulation environment for the insertion task:
```bash
python lerobot/scripts/train.py \
-env=aloha \
-task=sim_insertion \
-repo_id=lerobot/aloha_sim_insertion_scripted \
-policy=act \
-hydra.run.dir=outputs/train/aloha_act
+ policy=act \
+ env=aloha \
+ env.task=AlohaInsertion-v0 \
+ dataset_repo_id=lerobot/aloha_sim_insertion_human
```
-After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
+The experiment directory is automatically generated and will show up in yellow in your terminal. It looks like `outputs/train/2024-05-05/20-21-12_aloha_act_default`. You can manually specify an experiment directory by adding this argument to the `train.py` python command:
+```bash
+ hydra.run.dir=your/new/experiment/dir
+```
+
+To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:
+
+```bash
+ wandb.enable=true
+```
+
+A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
+
+
+
+Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. After training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python lerobot/scripts/eval.py --help` for more instructions.
## Contribute
@@ -173,98 +192,40 @@ If you would like to contribute to 🤗 LeRobot, please check out our [contribut
### Add a new dataset
-```python
-# TODO(rcadene, AdilZouitine): rewrite this section
-```
-
-To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
+To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
-Then you can upload it to the hub with:
+Then move your dataset folder into the `data` directory (e.g. `data/aloha_ping_pong`), and push your dataset to the hub with:
```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
---repo-type dataset \
---revision v1.0
+python lerobot/scripts/push_dataset_to_hub.py \
+--data-dir data \
+--dataset-id aloha_ping_pong \
+--raw-format aloha_hdf5 \
+--community-id lerobot
```
-You will need to set the corresponding version as a default argument in your dataset class:
-```python
- version: str | None = "v1.1",
-```
-See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
+See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
-For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
-```bash
-HF_USER=lerobot
-DATASET=pusht
-```
+If your dataset format is not supported, implement your own in `lerobot/common/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py).
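
As a rough, hypothetical skeleton of such a module (the function names and signatures mirror `aloha_hdf5_format.py`, which appears later in this diff; copy the real entry points from one of the referenced format files rather than relying on this sketch, and note that `my_format` is a placeholder name):

```python
# Hypothetical skeleton of lerobot/common/datasets/push_dataset_to_hub/my_format_format.py.
from datasets import Dataset, Features, Image, Sequence, Value

# Assumption: VideoFrame is defined in lerobot/common/datasets/video_utils in this version.
from lerobot.common.datasets.video_utils import VideoFrame


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    """Parse the raw files in `raw_dir` and return a flat dict of per-frame data,
    e.g. observation.images.*, observation.state, action, episode_index, timestamp."""
    data_dict = {}
    # ... read your raw episodes here ...
    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
    """Declare a feature type for every key and build a Hugging Face dataset."""
    features = {}
    for key in data_dict:
        if "observation.images." in key:
            features[key] = VideoFrame() if video else Image()
    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    # ... declare the remaining keys (action, episode_index, frame_index, timestamp, ...) ...
    return Dataset.from_dict(data_dict, features=Features(features))
```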
-If you want to improve an existing dataset, you can download it locally with:
-```bash
-mkdir -p data/$DATASET
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
---repo-type dataset \
---local-dir data/$DATASET \
---local-dir-use-symlinks=False \
---revision v1.0
-```
-
-Iterate on your code and dataset with:
-```bash
-DATA_DIR=data python train.py
-```
-
-Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
-```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
---repo-type dataset \
---revision v1.1 \
---delete "*"
-```
-
-Then you will need to set the corresponding version as a default argument in your dataset class:
-```python
- version: str | None = "v1.1",
-```
-See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
-
-
-Finally, you might want to mock the dataset if you need to update the unit tests as well:
-```bash
-python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
-```
### Add a pretrained policy
-```python
-# TODO(rcadene, alexander-soare): rewrite this section
-```
-
-Once you have trained a policy you may upload it to the HuggingFace hub.
-
-Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
-
-Secondly, assuming you have trained a policy, you need the following (which should all be in any of the subdirectories of `checkpoints` in your training output folder, if you've used the LeRobot training script):
+Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
+You first need to find the checkpoint located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). It should contain:
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
-- `model.safetensors`: The `torch.nn.Module` parameters saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
-- `config.yaml`: This is the consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
-
-To upload these to the hub, run the following with a desired revision ID.
+- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
+- `config.yaml`: A consolidated Hydra training configuration containing the policy, environment, and dataset configs. The policy configuration should match `config.json` exactly. The environment config is useful for anyone who wants to evaluate your policy. The dataset config just serves as a paper trail for reproducibility.
+To upload these to the hub, run the following:
```bash
-huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR --revision $REVISION_ID
+huggingface-cli upload ${hf_user}/${repo_name} path/to/checkpoint/dir
```
-If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
-
-```bash
-huggingface-cli upload $HUB_ID PATH/TO/OUTPUT/DIR
-```
-
-See `eval.py` for an example of how a user may use your policy.
+See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
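
For instance, a minimal sketch of how someone could load an uploaded policy, following the same pattern as `examples/2_evaluate_pretrained_policy.py` (using the existing `lerobot/diffusion_pusht` example above):

```python
# Minimal sketch of loading a policy uploaded to the hub, following example 2.
from pathlib import Path

from huggingface_hub import snapshot_download

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Download the checkpoint files (config.json, model.safetensors, config.yaml) locally.
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))

# Instantiate the policy from the downloaded configuration and weights.
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
```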
### Improve your code with profiling
@@ -291,9 +252,14 @@ with profile(
# insert code to profile, potentially whole body of eval_policy function
```
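
The `with profile(` block above is only partially visible in this hunk. As a rough, self-contained illustration of the same `torch.profiler` pattern, with `run_one_episode` standing in for the evaluation code you actually want to profile:

```python
# Self-contained sketch of the torch.profiler pattern shown above.
# `run_one_episode` is a placeholder (e.g. for the body of the eval_policy function).
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def run_one_episode():
    x = torch.randn(64, 512)
    return x @ x.T


with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when profiling on GPU
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    record_shapes=True,
) as prof:
    for _ in range(6):
        with record_function("eval_policy"):
            run_one_episode()
        prof.step()

# Print an aggregated table of the profiled operations.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```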
-```bash
-python lerobot/scripts/eval.py \
---config outputs/pusht/.hydra/config.yaml \
-pretrained_model_path=outputs/pusht/model.pt \
-eval_episodes=7
+## Citation
+
+If you want, you can cite this work with:
+```
+@misc{cadene2024lerobot,
+ author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+ title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
+ howpublished = "\url{https://github.com/huggingface/lerobot}",
+ year = {2024}
+}
```
diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile
index ab78937cf..a2823dc20 100644
--- a/docker/lerobot-gpu/Dockerfile
+++ b/docker/lerobot-gpu/Dockerfile
@@ -7,6 +7,11 @@ ARG DEBIAN_FRONTEND=noninteractive
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake \
+ git git-lfs openssh-client \
+ nano vim \
+ htop atop nvtop \
+ sed gawk grep curl wget \
+ tcpdump sysstat screen \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
@@ -18,7 +23,8 @@ ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot
-COPY . /lerobot
+RUN git lfs install
+RUN git clone https://github.com/huggingface/lerobot.git
WORKDIR /lerobot
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py
index f86199c52..3846926a6 100644
--- a/examples/1_load_lerobot_dataset.py
+++ b/examples/1_load_lerobot_dataset.py
@@ -14,6 +14,7 @@ The script ends with examples of how to batch process data using PyTorch's DataL
"""
from pathlib import Path
+from pprint import pprint
import imageio
import torch
@@ -21,39 +22,36 @@ import torch
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-print("List of available datasets", lerobot.available_datasets)
-# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
-# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
-# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
+print("List of available datasets:")
+pprint(lerobot.available_datasets)
+# Let's take one for this example
repo_id = "lerobot/pusht"
-# You can easily load a dataset from a Hugging Face repositery
+# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)
-# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
-# TODO(rcadene): update to make the print pretty
-print(f"{dataset=}")
-print(f"{dataset.hf_dataset=}")
+# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
+# (see https://huggingface.co/docs/datasets/index for more information).
+print(dataset)
+print(dataset.hf_dataset)
-# and provides additional utilities for robotics and compatibility with pytorch
-print(f"number of samples/frames: {dataset.num_samples=}")
-print(f"number of episodes: {dataset.num_episodes=}")
-print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
+# And provides additional utilities for robotics and compatibility with PyTorch
+print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
-print(f"keys to access images from cameras: {dataset.image_keys=}")
+print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
# Access frame indexes associated to first episode
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
-# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
-# Here we grab all the image frames.
+# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
+# with the latter, like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
-# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
-# To visualize them, we convert to uint8 range [0,255]
+# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
+# them, we convert to uint8 in range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
@@ -62,9 +60,9 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
-# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
-# Our datasets can load previous and future frames for each key/modality,
-# using timestamps differences with the current loaded frame. For instance:
+# For many machine learning applications we need to load the history of past observations or trajectories of
+# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
+# differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0],
@@ -74,12 +72,12 @@ delta_timestamps = {
"action": [t / dataset.fps for t in range(64)],
}
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
-print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
+print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
-print(f"{dataset[0]['action'].shape=}") # (64,c)
+print(f"{dataset[0]['action'].shape=}\n") # (64,c)
-# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
-# because they are just PyTorch datasets.
+# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
+# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
index 001054f3b..8481f0f54 100644
--- a/examples/2_evaluate_pretrained_policy.py
+++ b/examples/2_evaluate_pretrained_policy.py
@@ -5,23 +5,108 @@ training outputs directory. In the latter case, you might want to run examples/3
from pathlib import Path
+import gym_pusht # noqa: F401
+import gymnasium as gym
+import imageio
+import numpy
+import torch
from huggingface_hub import snapshot_download
-from lerobot.scripts.eval import eval
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-# Get a pretrained policy from the hub.
-pretrained_policy_name = "lerobot/diffusion_pusht"
-pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))
+# Create a directory to store the video of the evaluation
+output_directory = Path("outputs/eval/example_pusht_diffusion")
+output_directory.mkdir(parents=True, exist_ok=True)
+
+device = torch.device("cuda")
+
+# Download the diffusion policy for pusht environment
+pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
-# Override some config parameters to do with evaluation.
-overrides = [
- "eval.n_episodes=10",
- "eval.batch_size=10",
- "device=cuda",
-]
+policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
+policy.eval()
+policy.to(device)
-# Evaluate the policy and save the outputs including metrics and videos.
-# TODO(rcadene, alexander-soare): dont call eval, but add the minimal code snippet to rollout
-eval(pretrained_policy_path=pretrained_policy_path)
+# Initialize evaluation environment to render two observation types:
+# an image of the scene and state/position of the agent. The environment
+# also automatically stops running after 300 interactions/steps.
+env = gym.make(
+ "gym_pusht/PushT-v0",
+ obs_type="pixels_agent_pos",
+ max_episode_steps=300,
+)
+
+# Reset the policy and environments to prepare for rollout
+policy.reset()
+numpy_observation, info = env.reset(seed=42)
+
+# Prepare to collect all the rewards and all the frames of the episode,
+# from the initial state to the final state.
+rewards = []
+frames = []
+
+# Render frame of the initial state
+frames.append(env.render())
+
+step = 0
+done = False
+while not done:
+ # Prepare observation for the policy running in PyTorch
+ state = torch.from_numpy(numpy_observation["agent_pos"])
+ image = torch.from_numpy(numpy_observation["pixels"])
+
+ # Convert to float32 with image from channel last in [0,255]
+ # to channel first in [0,1]
+ state = state.to(torch.float32)
+ image = image.to(torch.float32) / 255
+ image = image.permute(2, 0, 1)
+
+ # Send data tensors from CPU to GPU
+ state = state.to(device, non_blocking=True)
+ image = image.to(device, non_blocking=True)
+
+ # Add extra (empty) batch dimension, required to forward the policy
+ state = state.unsqueeze(0)
+ image = image.unsqueeze(0)
+
+ # Create the policy input dictionary
+ observation = {
+ "observation.state": state,
+ "observation.image": image,
+ }
+
+ # Predict the next action with respect to the current observation
+ with torch.inference_mode():
+ action = policy.select_action(observation)
+
+ # Prepare the action for the environment
+ numpy_action = action.squeeze(0).to("cpu").numpy()
+
+ # Step through the environment and receive a new observation
+ numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
+ print(f"{step=} {reward=} {terminated=}")
+
+ # Keep track of all the rewards and frames
+ rewards.append(reward)
+ frames.append(env.render())
+
+ # The rollout is considered done when the success state is reached (i.e. terminated is True),
+ # or the maximum number of iterations is reached (i.e. truncated is True)
+ done = terminated | truncated | done
+ step += 1
+
+if terminated:
+ print("Success!")
+else:
+ print("Failure!")
+
+# Get the speed of the environment (i.e. its number of frames per second).
+fps = env.metadata["render_fps"]
+
+# Encode all frames into a mp4 video.
+video_path = output_directory / "rollout.mp4"
+imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
+
+print(f"Video of the evaluation is available in '{video_path}'.")
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 69e3d34cc..c5ce0d184 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -4,36 +4,42 @@ Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""
-import os
from pathlib import Path
import torch
-from omegaconf import OmegaConf
-from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.utils.utils import init_hydra_config
+# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht_diffusion")
-os.makedirs(output_directory, exist_ok=True)
+output_directory.mkdir(parents=True, exist_ok=True)
-# Number of offline training steps (we'll only do offline training for this example.
+# Number of offline training steps (we'll only do offline training for this example).
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 5000
device = torch.device("cuda")
log_freq = 250
# Set up the dataset.
-hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
-dataset = make_dataset(hydra_cfg)
+delta_timestamps = {
+ # Load the previous image and state 0.1 seconds before the current frame,
+ # then load the current image and state corresponding to 0.0 seconds.
+ "observation.image": [-0.1, 0.0],
+ "observation.state": [-0.1, 0.0],
+ # Load the previous action (-0.1), the next action to be executed (0.0),
+ # and 14 future actions with a 0.1 second spacing. All these actions will be
+ # used to supervise the policy.
+ "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
+}
+dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
# Set up the policy.
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
-# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy.train()
policy.to(device)
@@ -69,7 +75,5 @@ while not done:
done = True
break
-# Save the policy.
+# Save a policy checkpoint.
policy.save_pretrained(output_directory)
-# Save the Hydra configuration so we have the environment configuration for eval.
-OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
diff --git a/lerobot/__init__.py b/lerobot/__init__.py
index ed8c220a5..e188bc525 100644
--- a/lerobot/__init__.py
+++ b/lerobot/__init__.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
This file contains lists of available environments, datasets and policies to reflect the current state of the LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
@@ -85,13 +100,6 @@ available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)
-# TODO(rcadene, aliberts, alexander-soare): Add real-world env with a gym API
-available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
-
-available_datasets = list(
- itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
-)
-
available_policies = [
"act",
"diffusion",
diff --git a/lerobot/__version__.py b/lerobot/__version__.py
index 6232b699d..d12aafaa9 100644
--- a/lerobot/__version__.py
+++ b/lerobot/__version__.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""To enable `lerobot.__version__`"""
from importlib.metadata import PackageNotFoundError, version
diff --git a/lerobot/common/datasets/_video_benchmark/README.md b/lerobot/common/datasets/_video_benchmark/README.md
index 10e8d12f1..d802163fb 100644
--- a/lerobot/common/datasets/_video_benchmark/README.md
+++ b/lerobot/common/datasets/_video_benchmark/README.md
@@ -37,16 +37,16 @@ How to decode videos?
## Variables
**Image content**
-We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this bechmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
+We don't expect the same optimal settings for a dataset of images from a simulation, or from the real world in an apartment, or in a factory, or outdoors, etc. Hence, we run this benchmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
**Requested timestamps**
-In this benchmark, we focus on the loading time of random access, so we are not interested about sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
+In this benchmark, we focus on the loading time of random access, so we are not interested in sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
- `single_frame`: 1 frame,
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
- `2_frames_4_space`: 2 consecutive frames with 4 frames of spacing (e.g `[t, t + 4 / fps]`),
**Data augmentations**
-We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robusts (e.g. robust to color changes, compression, etc.).
+We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
## Results
diff --git a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
index 85d48fcfd..8be251dc1 100644
--- a/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
+++ b/lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import random
import shutil
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 93aec158e..6836b6ff7 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import logging
import torch
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index ffc1b1b66..881f253bb 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
from pathlib import Path
@@ -72,6 +87,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def fps(self) -> int:
+ """Frames per second used during data collection."""
return self.info["fps"]
@property
@@ -86,15 +102,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.hf_dataset.features
@property
- def image_keys(self) -> list[str]:
- image_keys = []
+ def camera_keys(self) -> list[str]:
+ """Keys to access image and video stream from cameras."""
+ keys = []
for key, feats in self.hf_dataset.features.items():
- if isinstance(feats, datasets.Image):
- image_keys.append(key)
- return image_keys + self.video_frame_keys
+ if isinstance(feats, (datasets.Image, VideoFrame)):
+ keys.append(key)
+ return keys
@property
- def video_frame_keys(self):
+ def video_frame_keys(self) -> list[str]:
+ """Keys to access video frames that requires to be decoded into images.
+
+ Note: It is empty if the dataset contains images only,
+ or equal to `self.cameras` if the dataset contains videos only,
+ or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
+ """
video_frame_keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, VideoFrame):
@@ -103,10 +126,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
+ """Number of possible samples in the dataset.
+
+ This is equivalent to the number of frames in the dataset minus n_end_keyframes_dropped.
+ """
return len(self.index)
@property
def num_episodes(self) -> int:
+ """Number of episodes."""
return len(self.hf_dataset.unique("episode_index"))
@property
@@ -146,6 +174,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}(\n"
+ f" Repository ID: '{self.repo_id}',\n"
+ f" Version: '{self.version}',\n"
+ f" Split: '{self.split}',\n"
+ f" Number of Samples: {self.num_samples},\n"
+ f" Number of Episodes: {self.num_episodes},\n"
+ f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
+ f" Recorded Frames per Second: {self.fps},\n"
+ f" Camera Keys: {self.camera_keys},\n"
+ f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
+ f" Transformations: {self.transform},\n"
+ f")"
+ )
+
@classmethod
def from_preloaded(
cls,
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py
index 2f5326508..33b4c9745 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
index d26f3d236..232fd0558 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.
diff --git a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py
index 1561fb886..a118b7e78 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# imagecodecs/numcodecs.py
# Copyright (c) 2021-2022, Christoph Gohlke
diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
index 694304f1d..4efadc9e0 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
"""
@@ -142,12 +157,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
- image_keys = [key for key in data_dict if "observation.images." in key]
- for image_key in image_keys:
+ keys = [key for key in data_dict if "observation.images." in key]
+ for key in keys:
if video:
- features[image_key] = VideoFrame()
+ features[key] = VideoFrame()
else:
- features[image_key] = Image()
+ features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py
index a7a952fb9..ec2966582 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from copy import deepcopy
from math import ceil
diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
index 0c3a8d19c..8133a36af 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy"""
import shutil
diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
index 008287506..cab2bdc52 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface"""
import logging
diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py
index 1b12c0b7e..4feb1dcfd 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/utils.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
index 686edf4ca..899ebdde7 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
import pickle
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 96b8fbbc9..5cdd5f7c0 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
from pathlib import Path
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index 0252be2ee..edfca918e 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import logging
import subprocess
import warnings
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index c5fd46711..83f94cfea 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import importlib
import gymnasium as gym
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
index 5370d3857..8fce03698 100644
--- a/lerobot/common/envs/utils.py
+++ b/lerobot/common/envs/utils.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import einops
import numpy as np
import torch
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 832eaa7e8..109f69511 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# TODO(rcadene, alexander-soare): clean this file
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py"""
@@ -17,9 +32,10 @@ def log_output_dir(out_dir):
def cfg_to_group(cfg, return_list=False):
- """Return a wandb-safe group name for logging. Optionally returns group name as list."""
- # lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
+ """Return a group name for logging. Optionally returns group name as list."""
lst = [
+ f"policy:{cfg.policy.name}",
+ f"dataset:{cfg.dataset_repo_id}",
f"env:{cfg.env.name}",
f"seed:{cfg.seed}",
]
@@ -81,9 +97,9 @@ class Logger:
# Also save the full Hydra config for the env configuration.
OmegaConf.save(self._cfg, save_dir / "config.yaml")
if self._wandb and not self._disable_wandb_artifact:
- # note wandb artifact does not accept ":" in its name
+ # note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(
- self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
+ f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="model",
)
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
@@ -93,9 +109,10 @@ class Logger:
self._buffer_dir.mkdir(parents=True, exist_ok=True)
fp = self._buffer_dir / f"{str(identifier)}.pkl"
buffer.save(fp)
- if self._wandb:
+ if self._wandb and not self._disable_wandb_artifact:
+ # note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(
- self._group + "-" + str(self._seed) + "-" + str(identifier),
+ f"{self._group.replace(':', '_').replace('/', '_')}-{self._seed}-{identifier}",
type="buffer",
)
artifact.add_file(fp)
@@ -113,6 +130,11 @@ class Logger:
assert mode in {"train", "eval"}
if self._wandb is not None:
for k, v in d.items():
+ if not isinstance(v, (int, float, str)):
+ logging.warning(
+ f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
+ )
+ continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
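For context on the logger changes above, a minimal sketch of the artifact-name sanitization; the helper name is hypothetical (the actual code inlines the expression):

```python
# Minimal sketch: wandb artifact names reject ":" and "/", so the group string built by
# `cfg_to_group` (e.g. "policy:act-dataset:lerobot/pusht-env:pusht-seed:100000") is sanitized.
def wandb_safe_artifact_name(group: str, seed: int, identifier: str) -> str:  # hypothetical helper
    sanitized = group.replace(":", "_").replace("/", "_")
    return f"{sanitized}-{seed}-{identifier}"

print(wandb_safe_artifact_name("policy:act-dataset:lerobot/pusht-env:pusht-seed:100000", 100000, "005000"))
# -> policy_act-dataset_lerobot_pusht-env_pusht-seed_100000-100000-005000
```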
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index a3980b14d..95374f4d5 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from dataclasses import dataclass, field
@@ -51,8 +66,12 @@ class ACTConfig:
documentation in the policy class).
latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
- use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
- environment step.
+ temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
+ actions for a given time step over multiple policy invocations. Updates are calculated as:
+ x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describe a different
+ parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
+ formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
+ require `n_action_steps == 1` (since we need to query the policy every step anyway).
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@@ -100,6 +119,9 @@ class ACTConfig:
dim_feedforward: int = 3200
feedforward_activation: str = "relu"
n_encoder_layers: int = 4
+ # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+ # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+ # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: int = 1
# VAE.
use_vae: bool = True
@@ -107,7 +129,7 @@ class ACTConfig:
n_vae_encoder_layers: int = 4
# Inference.
- use_temporal_aggregation: bool = False
+ temporal_ensemble_momentum: float | None = None
# Training and loss computation.
dropout: float = 0.1
@@ -119,8 +141,11 @@ class ACTConfig:
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
- if self.use_temporal_aggregation:
- raise NotImplementedError("Temporal aggregation is not yet implemented.")
+ if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
+ raise NotImplementedError(
+ "`n_action_steps` must be 1 when using temporal ensembling. This is "
+ "because the policy needs to be queried every step to compute the ensembled action."
+ )
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
@@ -130,10 +155,3 @@ class ACTConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
)
- # Check that there is only one image.
- # TODO(alexander-soare): generalize this to multiple images.
- if (
- sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
- or "observation.images.top" not in self.input_shapes
- ):
- raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
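To make the new `temporal_ensemble_momentum` semantics concrete, here is a small illustrative sketch (not the policy implementation) of the EMA update x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ and of the stated equivalence with the original paper's m = 0.01 weighting:

```python
import math

import torch

# α = exp(-m) with m = 0.01, as noted in the config docstring above.
alpha = math.exp(-0.01)
print(round(alpha, 5))  # 0.99005

# EMA over successive predictions of the *same* time step's action, one per policy invocation.
ensembled = None
for prediction in [torch.tensor([1.00]), torch.tensor([0.80]), torch.tensor([0.90])]:
    if ensembled is None:
        ensembled = prediction.clone()
    else:
        ensembled = alpha * ensembled + (1 - alpha) * prediction
print(ensembled)  # stays close to the first prediction because α is close to 1
```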
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 5ff25fea2..72ebdd7ad 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
@@ -46,7 +61,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
super().__init__()
if config is None:
config = ACTConfig()
- self.config = config
+ self.config: ACTConfig = config
+
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
@@ -56,11 +72,18 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
+
self.model = ACT(config)
+ self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+
+ self.reset()
+
def reset(self):
"""This should be called whenever the environment is reset."""
- if self.config.n_action_steps is not None:
+ if self.config.temporal_ensemble_momentum is not None:
+ self._ensembled_actions = None
+ else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
@@ -71,37 +94,56 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
- assert "observation.images.top" in batch
- assert "observation.state" in batch
-
self.eval()
batch = self.normalize_inputs(batch)
- self._stack_images(batch)
+ batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+ # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
+ # the first action.
+ if self.config.temporal_ensemble_momentum is not None:
+ actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
+ actions = self.unnormalize_outputs({"action": actions})["action"]
+ if self._ensembled_actions is None:
+ # Initializes `self._ensembled_actions` to the sequence of actions predicted during the first
+ # time step of the episode.
+ self._ensembled_actions = actions.clone()
+ else:
+ # self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
+ # the EMA update for those entries.
+ alpha = self.config.temporal_ensemble_momentum
+ self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
+ # The last action, which has no prior moving average, needs to get concatenated onto the end.
+ self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
+ # "Consume" the first action.
+ action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
+ return action
+
+ # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
+ # querying the policy.
if len(self._action_queue) == 0:
- # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
- # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
- actions = self.model(batch)[0][: self.config.n_action_steps]
+ actions = self.model(batch)[0][:, : self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
+ # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+ # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
+ batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
- self._stack_images(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
- loss_dict = {"l1_loss": l1_loss}
+ loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
@@ -110,28 +152,13 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
- loss_dict["kld_loss"] = mean_kld
+ loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
return loss_dict
- def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
- """Stacks all the images in a batch and puts them in a new key: "observation.images".
-
- This function expects `batch` to have (at least):
- {
- "observation.state": (B, state_dim) batch of robot states.
- "observation.images.{name}": (B, C, H, W) tensor of images.
- }
- """
- # Stack images in the order dictated by input_shapes.
- batch["observation.images"] = torch.stack(
- [batch[k] for k in self.config.input_shapes if k.startswith("observation.images.")],
- dim=-4,
- )
-
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
@@ -161,10 +188,10 @@ class ACT(nn.Module):
│ encoder │ │     │ │Transf.│             │
│         │ │     │ │encoder│             │
└───▲─────┘ │     │ │       │             │
-   │       │     │ └───▲───┘             │
-   │       │     │     │                 │
- inputs    └─────┼─────┘                 │
-                 │                       │
+   │       │     │ └▲──▲─▲─┘             │
+   │       │     │  │  │ │               │
+ inputs    └─────┼──┘  │ image emb.      │
+                 │    state emb.         │
                  └───────────────────────┘
"""
@@ -306,18 +333,18 @@ class ACT(nn.Module):
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
- encoder_in = torch.cat(all_cam_features, axis=3)
- cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
+ encoder_in = torch.cat(all_cam_features, axis=-1)
+ cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
- robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
- latent_embed = self.encoder_latent_input_proj(latent_sample)
+ robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
+ latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
- encoder_in.flatten(2).permute(2, 0, 1),
+ einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
]
)
pos_embed = torch.cat(
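The `einops.rearrange` call introduced above should be behavior-preserving relative to the old `flatten(2).permute(2, 0, 1)`; a quick sanity-check sketch with arbitrary shapes:

```python
import einops
import torch

x = torch.randn(2, 8, 3, 5)  # (batch, channels, height, width)
old = x.flatten(2).permute(2, 0, 1)                 # (h*w, batch, channels)
new = einops.rearrange(x, "b c h w -> (h w) b c")   # same layout, spelled out explicitly
assert torch.equal(old, new)
```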
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 73fabefad..632f6cd69 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from dataclasses import dataclass, field
@@ -51,6 +67,7 @@ class DiffusionConfig:
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
Bias modulation is used by default, while this parameter indicates whether to also use scale
modulation.
+ noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step.
@@ -64,6 +81,9 @@ class DiffusionConfig:
clip_sample_range: The magnitude of the clipping range as described above.
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
+ do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
+ `LeRobotDataset` and `load_previous_and_future_frames` for more information. Note that this defaults
+ to False as the original Diffusion Policy implementation does the same.
"""
# Inputs / output structure.
@@ -107,6 +127,7 @@ class DiffusionConfig:
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
+ noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
@@ -118,23 +139,39 @@ class DiffusionConfig:
# Inference
num_inference_steps: int | None = None
+ # Loss computation
+ do_mask_loss_for_padding: bool = False
+
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
+ # There should only be one image key.
+ image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+ if len(image_keys) != 1:
+ raise ValueError(
+ f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+ )
+ image_key = next(iter(image_keys))
if (
- self.crop_shape[0] > self.input_shapes["observation.image"][1]
- or self.crop_shape[1] > self.input_shapes["observation.image"][2]
+ self.crop_shape[0] > self.input_shapes[image_key][1]
+ or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
- f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
- f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
- '`input_shapes["observation.image"]`.'
+ f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+ f"for `crop_shape` and {self.input_shapes[image_key]} for "
+ f"`input_shapes[{image_key}]`."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)
+ supported_noise_schedulers = ["DDPM", "DDIM"]
+ if self.noise_scheduler_type not in supported_noise_schedulers:
+ raise ValueError(
+ f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
+ f"Got {self.noise_scheduler_type}."
+ )
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index f5f64d806..2ae03f221 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -1,8 +1,24 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare):
- - Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
+ - Make compatible with multiple image keys.
"""
import math
@@ -10,12 +26,13 @@ from collections import deque
from typing import Callable
import einops
+import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
-from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
@@ -66,10 +83,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
self.diffusion = DiffusionModel(config)
+ image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+ # Note: This check is covered in the post-init of the config, but we keep a sanity check here just in case.
+ if len(image_keys) != 1:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+ )
+ self.input_image_key = image_keys[0]
+
+ self.reset()
+
def reset(self):
- """
- Clear observation and action queues. Should be called on `env.reset()`
- """
+ """Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
@@ -98,16 +123,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"horizon" may not be the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
- assert "observation.image" in batch
- assert "observation.state" in batch
-
batch = self.normalize_inputs(batch)
+ batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
- batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
+ batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
@@ -121,11 +144,25 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
+ batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
+def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
+ """
+ Factory for noise scheduler instances of the requested type. All kwargs are passed
+ to the scheduler.
+ """
+ if name == "DDPM":
+ return DDPMScheduler(**kwargs)
+ elif name == "DDIM":
+ return DDIMScheduler(**kwargs)
+ else:
+ raise ValueError(f"Unsupported noise scheduler type {name}")
+
+
class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
@@ -138,12 +175,12 @@ class DiffusionModel(nn.Module):
* config.n_obs_steps,
)
- self.noise_scheduler = DDPMScheduler(
+ self.noise_scheduler = _make_noise_scheduler(
+ config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
- variance_type="fixed_small",
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,
@@ -185,13 +222,12 @@ class DiffusionModel(nn.Module):
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
- This function expects `batch` to have (at least):
+ This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
}
"""
- assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
@@ -268,13 +304,84 @@ class DiffusionModel(nn.Module):
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
- if "action_is_pad" in batch:
+ if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
+class SpatialSoftmax(nn.Module):
+ """
+ Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
+ (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
+
+ At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
+ of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
+
+ Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
+ -----------------------------------------------------
+ | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
+ | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
+ | ... | ... | ... | ... |
+ | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
+ -----------------------------------------------------
+ This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
+ product with the coordinates (120x2) to get expected points of maximal activation (512x2).
+
+ The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
+ provide num_kp != None to control the number of keypoints. This is achieved by first applying a learnable
+ linear mapping (in_channels, H, W) -> (num_kp, H, W).
+ """
+
+ def __init__(self, input_shape, num_kp=None):
+ """
+ Args:
+ input_shape (list): (C, H, W) input feature map shape.
+ num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
+ """
+ super().__init__()
+
+ assert len(input_shape) == 3
+ self._in_c, self._in_h, self._in_w = input_shape
+
+ if num_kp is not None:
+ self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+ self._out_c = num_kp
+ else:
+ self.nets = None
+ self._out_c = self._in_c
+
+ # we could use torch.linspace directly but that seems to behave slightly differently than numpy
+ # and causes a small degradation in pc_success of pre-trained models.
+ pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
+ pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
+ pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
+ # register as buffer so it's moved to the correct device.
+ self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
+
+ def forward(self, features: Tensor) -> Tensor:
+ """
+ Args:
+ features: (B, C, H, W) input feature maps.
+ Returns:
+ (B, K, 2) image-space coordinates of keypoints.
+ """
+ if self.nets is not None:
+ features = self.nets(features)
+
+ # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
+ features = features.reshape(-1, self._in_h * self._in_w)
+ # 2d softmax normalization
+ attention = F.softmax(features, dim=-1)
+ # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
+ expected_xy = attention @ self.pos_grid
+ # reshape to [B, K, 2]
+ feature_keypoints = expected_xy.view(-1, self._out_c, 2)
+
+ return feature_keypoints
+
+
class DiffusionRgbEncoder(nn.Module):
"""Encodes an RGB image into a 1D feature vector.
@@ -315,11 +422,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
+ # The dummy input should take the number of image channels from `config.input_shapes` and it should
+ # use the height and width from `config.crop_shape`.
+ image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+ assert len(image_keys) == 1
+ image_key = image_keys[0]
+ dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
with torch.inference_mode():
- feat_map_shape = tuple(
- self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
- )
- self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
+ dummy_feature_map = self.backbone(dummy_input)
+ feature_map_shape = tuple(dummy_feature_map.shape[1:])
+ self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
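A brief usage sketch for the `SpatialSoftmax` port above, just to pin down the expected shapes (the numbers mirror the docstring's 512x10x12 example):

```python
import torch

from lerobot.common.policies.diffusion.modeling_diffusion import SpatialSoftmax

feature_maps = torch.randn(4, 512, 10, 12)  # (batch, channels, height, width) from the backbone
pool = SpatialSoftmax(input_shape=(512, 10, 12), num_kp=32)
keypoints = pool(feature_maps)
print(keypoints.shape)  # torch.Size([4, 32, 2]): one (x, y) keypoint per output channel
```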
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 4819ca809..4c124b617 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,4 +1,20 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import inspect
+import logging
from omegaconf import DictConfig, OmegaConf
@@ -8,9 +24,10 @@ from lerobot.common.utils.utils import get_safe_torch_device
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
- assert set(hydra_cfg.policy).issuperset(
- expected_kwargs
- ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
+ if not set(hydra_cfg.policy).issuperset(expected_kwargs):
+ logging.warning(
+ f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
+ )
policy_cfg = policy_cfg_class(
**{
k: v
@@ -62,11 +79,18 @@ def make_policy(
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
+ policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None:
- policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
+ # Make a fresh policy.
policy = policy_cls(policy_cfg, dataset_stats)
else:
- policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
+ # Load a pretrained policy and override the config if needed (for example, if there are inference-time
+ # hyperparameters that we want to vary).
+ # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with pretrained
+ # weights, which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
+ # make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
+ policy = policy_cls(policy_cfg)
+ policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
policy.to(get_safe_torch_device(hydra_cfg.device))
diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py
index ab57c8ba2..d638c5416 100644
--- a/lerobot/common/policies/normalize.py
+++ b/lerobot/common/policies/normalize.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import torch
from torch import Tensor, nn
diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py
index 5749c6a87..38738a909 100644
--- a/lerobot/common/policies/policy_protocol.py
+++ b/lerobot/common/policies/policy_protocol.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
@@ -38,7 +53,8 @@ class Policy(Protocol):
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
- Returns a dictionary with "loss" and maybe other information.
+ Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
+ other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]):
diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
index 82e3a5075..cf76fb08a 100644
--- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from dataclasses import dataclass, field
@@ -47,7 +63,7 @@ class TDMPCConfig:
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
- paramters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
+ parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
@@ -131,12 +147,18 @@ class TDMPCConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
- if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
+ # There should only be one image key.
+ image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
+ if len(image_keys) != 1:
+ raise ValueError(
+ f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
+ )
+ image_key = next(iter(image_keys))
+ if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
- "Only square images are handled now. Got image shape "
- f"{self.input_shapes['observation.image']}."
+ f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index 1fba43d08..7c873bf23 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Implementation of Finetuning Offline World Models in the Real World.
The comments in this code may sometimes refer to these references:
@@ -96,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
config.output_shapes, config.output_normalization_modes, dataset_stats
)
- def save(self, fp):
- """Save state dict of TOLD model to filepath."""
- torch.save(self.state_dict(), fp)
+ image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
+ # Note: This check is covered in the post-init of the config, but we keep a sanity check here just in case.
+ assert len(image_keys) == 1
+ self.input_image_key = image_keys[0]
- def load(self, fp):
- """Load a saved state dict from filepath into current agent."""
- self.load_state_dict(torch.load(fp))
+ self.reset()
def reset(self):
"""
@@ -121,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
"""Select a single action given environment observations."""
- assert "observation.image" in batch
- assert "observation.state" in batch
-
batch = self.normalize_inputs(batch)
+ batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -303,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
+ batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
info = {}
- # TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
- batch_size = batch["index"].shape[0]
-
# (b, t) -> (t, b)
for key in batch:
if batch[key].ndim > 1:
@@ -337,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
+ batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py
index b23c13366..5a62daa2a 100644
--- a/lerobot/common/policies/utils.py
+++ b/lerobot/common/policies/utils.py
@@ -1,9 +1,28 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import torch
from torch import nn
def populate_queues(queues, batch):
for key in batch:
+ # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
+ # queues have the keys they want).
+ if key not in queues:
+ continue
if len(queues[key]) != queues[key].maxlen:
# initialize by copying the first observation several times until the queue is full
while len(queues[key]) != queues[key].maxlen:
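The guard added to `populate_queues` above means callers may pass batches with extra keys; a hedged sketch of the intended behaviour (assuming the function still fills a fresh queue up to `maxlen` and returns it, as in the existing implementation):

```python
from collections import deque

import torch

from lerobot.common.policies.utils import populate_queues

queues = {"observation.state": deque(maxlen=2)}
batch = {
    "observation.state": torch.zeros(1, 2),
    "observation.images.top": torch.zeros(1, 3, 96, 96),  # no matching queue -> now simply ignored
}
queues = populate_queues(queues, batch)
print({key: len(q) for key, q in queues.items()})  # {'observation.state': 2}
```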
diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py
index 642e0ff17..cd5f82450 100644
--- a/lerobot/common/utils/import_utils.py
+++ b/lerobot/common/utils/import_utils.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import importlib
import logging
diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py
index 5d727bd74..b85f17c7a 100644
--- a/lerobot/common/utils/io_utils.py
+++ b/lerobot/common/utils/io_utils.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import warnings
import imageio
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index 9d0ddd986..d62507b59 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -1,8 +1,25 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import logging
import os.path as osp
import random
+from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
+from typing import Generator
import hydra
import numpy as np
@@ -39,6 +56,31 @@ def set_global_seed(seed):
torch.cuda.manual_seed_all(seed)
+@contextmanager
+def seeded_context(seed: int) -> Generator[None, None, None]:
+ """Set the seed when entering a context, and restore the prior random state at exit.
+
+ Example usage:
+
+ ```
+ a = random.random() # produces some random number
+ with seeded_context(1337):
+ b = random.random() # produces some other random number
+ c = random.random() # produces yet another random number, but the same as it would have been had we never made `b`
+ ```
+ """
+ random_state = random.getstate()
+ np_random_state = np.random.get_state()
+ torch_random_state = torch.random.get_rng_state()
+ torch_cuda_random_state = torch.cuda.random.get_rng_state()
+ set_global_seed(seed)
+ yield None
+ random.setstate(random_state)
+ np.random.set_state(np_random_state)
+ torch.random.set_rng_state(torch_random_state)
+ torch.cuda.random.set_rng_state(torch_cuda_random_state)
+
+
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 37695858e..ae36b3e23 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -25,7 +25,7 @@ training:
eval_freq: ???
save_freq: ???
log_freq: 250
- save_model: false
+ save_model: true
eval:
n_episodes: 1
@@ -35,7 +35,7 @@ eval:
use_async_envs: false
wandb:
- enable: true
+ enable: false
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot
diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml
index a49a97f81..7a12dcc2c 100644
--- a/lerobot/configs/policy/act.yaml
+++ b/lerobot/configs/policy/act.yaml
@@ -3,6 +3,12 @@
seed: 1000
dataset_repo_id: lerobot/aloha_sim_insertion_human
+override_dataset_stats:
+ observation.images.top:
+ # stats from imagenet, since we use a pretrained vision model
+ mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
+ std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
+
training:
offline_steps: 80000
online_steps: 0
@@ -18,12 +24,6 @@ training:
grad_clip_norm: 10
online_steps_between_rollouts: 1
- override_dataset_stats:
- observation.images.top:
- # stats from imagenet, since we use a pretrained vision model
- mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
- std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
-
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
@@ -66,6 +66,9 @@ policy:
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
+ # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
+ # that means only the first layer is used. Here we match the original implementation by setting this to 1.
+ # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
@@ -73,7 +76,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
- use_temporal_aggregation: false
+ temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index e77a40e3a..0dcff592a 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -7,6 +7,20 @@
seed: 100000
dataset_repo_id: lerobot/pusht
+override_dataset_stats:
+ # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
+ observation.image:
+ mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+ std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
+ # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+ # from the original codebase, but we should remove these and train our own pretrained model
+ observation.state:
+ min: [13.456424, 32.938293]
+ max: [496.14618, 510.9579]
+ action:
+ min: [12.0, 25.0]
+ max: [511.0, 511.0]
+
training:
offline_steps: 200000
online_steps: 0
@@ -44,20 +58,6 @@ eval:
n_episodes: 50
batch_size: 50
-override_dataset_stats:
- # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
- observation.image:
- mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
- std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
- # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
- # from the original codebase, but we should remove these and train our own pretrained model
- observation.state:
- min: [13.456424, 32.938293]
- max: [496.14618, 510.9579]
- action:
- min: [12.0, 25.0]
- max: [511.0, 511.0]
-
policy:
name: diffusion
@@ -95,6 +95,7 @@ policy:
diffusion_step_embed_dim: 128
use_film_scale_modulation: True
# Noise scheduler.
+ noise_scheduler_type: DDPM
num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
@@ -105,3 +106,6 @@ policy:
# Inference
num_inference_steps: 100
+
+ # Loss computation
+ do_mask_loss_for_padding: false
diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml
index eb89033bc..7e736850e 100644
--- a/lerobot/configs/policy/tdmpc.yaml
+++ b/lerobot/configs/policy/tdmpc.yaml
@@ -1,7 +1,7 @@
# @package _global_
seed: 1
-dataset_repo_id: lerobot/xarm_lift_medium_replay
+dataset_repo_id: lerobot/xarm_lift_medium
training:
offline_steps: 25000
diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py
index e4ea4260c..4d8b48504 100644
--- a/lerobot/scripts/display_sys_info.py
+++ b/lerobot/scripts/display_sys_info.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import platform
import huggingface_hub
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 2a61aaea8..9c95633a1 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Evaluate a policy on an environment by running rollouts and computing metrics.
Usage examples:
@@ -583,17 +598,18 @@ if __name__ == "__main__":
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
- except HFValidationError:
- logging.warning(
- "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID. "
- "Treating it as a local directory."
- )
- except RepositoryNotFoundError:
- logging.warning(
- "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub. Treating "
- "it as a local directory."
- )
- pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
+ except (HFValidationError, RepositoryNotFoundError) as e:
+ if isinstance(e, HFValidationError):
+ error_message = (
+ "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
+ )
+ else:
+ error_message = (
+ "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
+ )
+
+ logging.warning(f"{error_message} Treating it as a local directory.")
+ pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py
index ca8c46004..16d890a79 100644
--- a/lerobot/scripts/push_dataset_to_hub.py
+++ b/lerobot/scripts/push_dataset_to_hub.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
@@ -60,7 +75,7 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict
@@ -252,7 +267,7 @@ def main():
parser.add_argument(
"--revision",
type=str,
- default="v1.2",
+ default=CODEBASE_VERSION,
help="Codebase version used to generate the dataset.",
)
parser.add_argument(
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 6cbc82652..7ca7a0b3c 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import logging
import time
from copy import deepcopy
@@ -8,7 +23,7 @@ import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
-from diffusers.optimization import get_scheduler
+from omegaconf import DictConfig
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -55,6 +70,8 @@ def make_optimizer_and_scheduler(cfg, policy):
cfg.training.adam_weight_decay,
)
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
+ from diffusers.optimization import get_scheduler
+
lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
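
Moving the `get_scheduler` import inside the diffusion-policy branch means `diffusers` is only needed when that branch actually runs, which matters for the new minimal-install CI job. A hedged sketch of the lazy-import pattern (the helper name and config keys here are illustrative):

```python
# Hedged sketch: defer an optional dependency to the branch that needs it, so a
# minimal install without `diffusers` can still import and run the other policies.
def make_lr_scheduler(cfg, optimizer, num_training_steps: int):
    if cfg.policy.name != "diffusion":
        return None

    # Imported lazily: only required when the diffusion policy is trained.
    from diffusers.optimization import get_scheduler

    return get_scheduler(
        cfg.training.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.training.lr_warmup_steps,
        num_training_steps=num_training_steps,
    )
```
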
@@ -71,6 +88,7 @@ def make_optimizer_and_scheduler(cfg, policy):
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+ """Returns a dictionary of items for logging."""
start_time = time.time()
policy.train()
output_dict = policy.forward(batch)
@@ -98,6 +116,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.time() - start_time,
+ **{k: v for k, v in output_dict.items() if k != "loss"},
}
return info
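
The added dict-unpacking line forwards any extra items the policy returns from `forward()` (everything except `"loss"`, which is already logged explicitly) into the training info. A tiny self-contained sketch with made-up keys and values:

```python
# Tiny sketch: merge a policy's extra forward() outputs into the logging dict,
# keeping "loss" out since it is already reported explicitly. Values are made up.
output_dict = {"loss": 0.42, "l1_loss": 0.30, "kld_loss": 0.12}

info = {
    "loss": output_dict["loss"],
    "lr": 1e-4,
    **{k: v for k, v in output_dict.items() if k != "loss"},
}
print(info)  # {'loss': 0.42, 'lr': 0.0001, 'l1_loss': 0.3, 'kld_loss': 0.12}
```
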
@@ -121,7 +140,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
train(cfg, out_dir=out_dir, job_name=job_name)
-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
@@ -289,7 +308,7 @@ def add_episodes_inplace(
sampler.num_samples = len(concat_dataset)
-def train(cfg: dict, out_dir=None, job_name=None):
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
@@ -336,7 +355,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops.
- def _maybe_eval_and_maybe_save(step):
+ def evaluate_and_checkpoint_if_needed(step):
if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info = eval_policy(
@@ -392,9 +411,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
- # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
- # step + 1.
- _maybe_eval_and_maybe_save(step + 1)
+ # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
+ # so we pass in step + 1.
+ evaluate_and_checkpoint_if_needed(step + 1)
step += 1
@@ -460,9 +479,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
- # Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
- # in step + 1.
- _maybe_eval_and_maybe_save(step + 1)
+ # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
+ # so we pass in step + 1.
+ evaluate_and_checkpoint_if_needed(step + 1)
step += 1
online_step += 1
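
Both renamed call sites keep the same off-by-one convention: the hook receives `step + 1` because it runs after the `step`-th update has finished, so "every `eval_freq` steps" means after updates `eval_freq`, `2 * eval_freq`, and so on. A toy sketch of that cadence with an illustrative frequency:

```python
# Toy sketch of the eval/checkpoint cadence: the hook is passed step + 1 because it
# runs *after* the step-th training update has completed.
eval_freq = 3  # illustrative; the real value comes from cfg.training.eval_freq


def evaluate_and_checkpoint_if_needed(step: int) -> None:
    # `step` counts completed training updates.
    if step % eval_freq == 0:
        print(f"evaluating/checkpointing after {step} updates")


step = 0
for _ in range(9):
    # ... one training update happens here ...
    evaluate_and_checkpoint_if_needed(step + 1)  # fires after updates 3, 6, 9
    step += 1
```
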
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index 44acd4169..58da6a47e 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesnt always correspond to a final state.
@@ -32,7 +47,7 @@ local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
-(You need to forward the websocket port to the distant machine, with
+(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python lerobot/scripts/visualize_dataset.py \
@@ -47,6 +62,7 @@ local$ rerun ws://localhost:9087
"""
import argparse
+import gc
import logging
import time
from pathlib import Path
@@ -115,15 +131,17 @@ def visualize_dataset(
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+
+    # Manually call the Python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0.
+ # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+ gc.collect()
+
if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun")
- if num_workers > 0:
- # TODO(rcadene): fix data workers hanging when `rr.init` is called
- logging.warning("If data loader is hanging, try `--num-workers 0`.")
-
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch
for i in range(len(batch["index"])):
@@ -131,7 +149,7 @@ def visualize_dataset(
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
- for key in dataset.image_keys:
+ for key in dataset.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
@@ -196,7 +214,7 @@ def main():
parser.add_argument(
"--num-workers",
type=int,
- default=0,
+ default=4,
help="Number of processes of Dataloader for loading the data.",
)
parser.add_argument(
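
The `gc.collect()` workaround and the new `--num-workers 4` default belong together: once the post-`rr.init` flush no longer hangs, multiple dataloader workers are safe to use by default. A hedged, self-contained sketch of the pattern (the dataset, entity path, and image shapes are placeholders):

```python
# Hedged sketch of the workaround above: force a garbage collection right after rr.init
# so a DataLoader with num_workers > 0 does not hang in a blocking flush (rerun < 0.16).
import gc

import rerun as rr
import torch
from torch.utils.data import DataLoader, TensorDataset


def main() -> None:
    dataset = TensorDataset(torch.rand(64, 3, 96, 96))  # placeholder images (CHW floats)
    dataloader = DataLoader(dataset, batch_size=8, num_workers=4)

    rr.init("example/episode_0", spawn=False)
    # Workaround from the hunk above: collect garbage right after rr.init so worker
    # processes don't hang in a blocking flush (needed until rerun 0.16).
    gc.collect()

    for (imgs,) in dataloader:
        for i in range(len(imgs)):
            # rr.Image expects HWC; convert the CHW float tensor to uint8 HWC.
            rr.log("camera", rr.Image((imgs[i].permute(1, 2, 0) * 255).to(torch.uint8).numpy()))


if __name__ == "__main__":
    main()
```
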
diff --git a/media/wandb.png b/media/wandb.png
new file mode 100644
index 000000000..8adc3d2ae
Binary files /dev/null and b/media/wandb.png differ
diff --git a/poetry.lock b/poetry.lock
index 616f4a6a4..e0b27f159 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4,7 +4,7 @@
name = "absl-py"
version = "2.1.0"
description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"},
@@ -131,17 +131,6 @@ files = [
{file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"},
]
-[[package]]
-name = "appdirs"
-version = "1.4.4"
-description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
-optional = false
-python-versions = "*"
-files = [
- {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"},
- {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"},
-]
-
[[package]]
name = "asciitree"
version = "0.3.3"
@@ -455,63 +444,63 @@ files = [
[[package]]
name = "coverage"
-version = "7.5.0"
+version = "7.5.1"
description = "Code coverage measurement for Python"
optional = true
python-versions = ">=3.8"
files = [
- {file = "coverage-7.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:432949a32c3e3f820af808db1833d6d1631664d53dd3ce487aa25d574e18ad1c"},
- {file = "coverage-7.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2bd7065249703cbeb6d4ce679c734bef0ee69baa7bff9724361ada04a15b7e3b"},
- {file = "coverage-7.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbfe6389c5522b99768a93d89aca52ef92310a96b99782973b9d11e80511f932"},
- {file = "coverage-7.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39793731182c4be939b4be0cdecde074b833f6171313cf53481f869937129ed3"},
- {file = "coverage-7.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85a5dbe1ba1bf38d6c63b6d2c42132d45cbee6d9f0c51b52c59aa4afba057517"},
- {file = "coverage-7.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:357754dcdfd811462a725e7501a9b4556388e8ecf66e79df6f4b988fa3d0b39a"},
- {file = "coverage-7.5.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a81eb64feded34f40c8986869a2f764f0fe2db58c0530d3a4afbcde50f314880"},
- {file = "coverage-7.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:51431d0abbed3a868e967f8257c5faf283d41ec882f58413cf295a389bb22e58"},
- {file = "coverage-7.5.0-cp310-cp310-win32.whl", hash = "sha256:f609ebcb0242d84b7adeee2b06c11a2ddaec5464d21888b2c8255f5fd6a98ae4"},
- {file = "coverage-7.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:6782cd6216fab5a83216cc39f13ebe30adfac2fa72688c5a4d8d180cd52e8f6a"},
- {file = "coverage-7.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e768d870801f68c74c2b669fc909839660180c366501d4cc4b87efd6b0eee375"},
- {file = "coverage-7.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:84921b10aeb2dd453247fd10de22907984eaf80901b578a5cf0bb1e279a587cb"},
- {file = "coverage-7.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:710c62b6e35a9a766b99b15cdc56d5aeda0914edae8bb467e9c355f75d14ee95"},
- {file = "coverage-7.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c379cdd3efc0658e652a14112d51a7668f6bfca7445c5a10dee7eabecabba19d"},
- {file = "coverage-7.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fea9d3ca80bcf17edb2c08a4704259dadac196fe5e9274067e7a20511fad1743"},
- {file = "coverage-7.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:41327143c5b1d715f5f98a397608f90ab9ebba606ae4e6f3389c2145410c52b1"},
- {file = "coverage-7.5.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:565b2e82d0968c977e0b0f7cbf25fd06d78d4856289abc79694c8edcce6eb2de"},
- {file = "coverage-7.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cf3539007202ebfe03923128fedfdd245db5860a36810136ad95a564a2fdffff"},
- {file = "coverage-7.5.0-cp311-cp311-win32.whl", hash = "sha256:bf0b4b8d9caa8d64df838e0f8dcf68fb570c5733b726d1494b87f3da85db3a2d"},
- {file = "coverage-7.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:9c6384cc90e37cfb60435bbbe0488444e54b98700f727f16f64d8bfda0b84656"},
- {file = "coverage-7.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fed7a72d54bd52f4aeb6c6e951f363903bd7d70bc1cad64dd1f087980d309ab9"},
- {file = "coverage-7.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cbe6581fcff7c8e262eb574244f81f5faaea539e712a058e6707a9d272fe5b64"},
- {file = "coverage-7.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad97ec0da94b378e593ef532b980c15e377df9b9608c7c6da3506953182398af"},
- {file = "coverage-7.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd4bacd62aa2f1a1627352fe68885d6ee694bdaebb16038b6e680f2924a9b2cc"},
- {file = "coverage-7.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:adf032b6c105881f9d77fa17d9eebe0ad1f9bfb2ad25777811f97c5362aa07f2"},
- {file = "coverage-7.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4ba01d9ba112b55bfa4b24808ec431197bb34f09f66f7cb4fd0258ff9d3711b1"},
- {file = "coverage-7.5.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:f0bfe42523893c188e9616d853c47685e1c575fe25f737adf473d0405dcfa7eb"},
- {file = "coverage-7.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a9a7ef30a1b02547c1b23fa9a5564f03c9982fc71eb2ecb7f98c96d7a0db5cf2"},
- {file = "coverage-7.5.0-cp312-cp312-win32.whl", hash = "sha256:3c2b77f295edb9fcdb6a250f83e6481c679335ca7e6e4a955e4290350f2d22a4"},
- {file = "coverage-7.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:427e1e627b0963ac02d7c8730ca6d935df10280d230508c0ba059505e9233475"},
- {file = "coverage-7.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9dd88fce54abbdbf4c42fb1fea0e498973d07816f24c0e27a1ecaf91883ce69e"},
- {file = "coverage-7.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a898c11dca8f8c97b467138004a30133974aacd572818c383596f8d5b2eb04a9"},
- {file = "coverage-7.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07dfdd492d645eea1bd70fb1d6febdcf47db178b0d99161d8e4eed18e7f62fe7"},
- {file = "coverage-7.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3d117890b6eee85887b1eed41eefe2e598ad6e40523d9f94c4c4b213258e4a4"},
- {file = "coverage-7.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6afd2e84e7da40fe23ca588379f815fb6dbbb1b757c883935ed11647205111cb"},
- {file = "coverage-7.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a9960dd1891b2ddf13a7fe45339cd59ecee3abb6b8326d8b932d0c5da208104f"},
- {file = "coverage-7.5.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ced268e82af993d7801a9db2dbc1d2322e786c5dc76295d8e89473d46c6b84d4"},
- {file = "coverage-7.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7c211f25777746d468d76f11719e64acb40eed410d81c26cefac641975beb88"},
- {file = "coverage-7.5.0-cp38-cp38-win32.whl", hash = "sha256:262fffc1f6c1a26125d5d573e1ec379285a3723363f3bd9c83923c9593a2ac25"},
- {file = "coverage-7.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:eed462b4541c540d63ab57b3fc69e7d8c84d5957668854ee4e408b50e92ce26a"},
- {file = "coverage-7.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d0194d654e360b3e6cc9b774e83235bae6b9b2cac3be09040880bb0e8a88f4a1"},
- {file = "coverage-7.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:33c020d3322662e74bc507fb11488773a96894aa82a622c35a5a28673c0c26f5"},
- {file = "coverage-7.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbdf2cae14a06827bec50bd58e49249452d211d9caddd8bd80e35b53cb04631"},
- {file = "coverage-7.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3235d7c781232e525b0761730e052388a01548bd7f67d0067a253887c6e8df46"},
- {file = "coverage-7.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2de4e546f0ec4b2787d625e0b16b78e99c3e21bc1722b4977c0dddf11ca84e"},
- {file = "coverage-7.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4d0e206259b73af35c4ec1319fd04003776e11e859936658cb6ceffdeba0f5be"},
- {file = "coverage-7.5.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2055c4fb9a6ff624253d432aa471a37202cd8f458c033d6d989be4499aed037b"},
- {file = "coverage-7.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:075299460948cd12722a970c7eae43d25d37989da682997687b34ae6b87c0ef0"},
- {file = "coverage-7.5.0-cp39-cp39-win32.whl", hash = "sha256:280132aada3bc2f0fac939a5771db4fbb84f245cb35b94fae4994d4c1f80dae7"},
- {file = "coverage-7.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:c58536f6892559e030e6924896a44098bc1290663ea12532c78cef71d0df8493"},
- {file = "coverage-7.5.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:2b57780b51084d5223eee7b59f0d4911c31c16ee5aa12737c7a02455829ff067"},
- {file = "coverage-7.5.0.tar.gz", hash = "sha256:cf62d17310f34084c59c01e027259076479128d11e4661bb6c9acb38c5e19bb8"},
+ {file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
+ {file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
+ {file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
+ {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
+ {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
+ {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
+ {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
+ {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
+ {file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
+ {file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
+ {file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
+ {file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
+ {file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
+ {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
+ {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
+ {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
+ {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
+ {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
+ {file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
+ {file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
+ {file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
+ {file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
+ {file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
+ {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
+ {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
+ {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
+ {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
+ {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
+ {file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
+ {file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
+ {file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
+ {file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
+ {file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
+ {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
+ {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
+ {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
+ {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
+ {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
+ {file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
+ {file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
+ {file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
+ {file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
+ {file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
+ {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
+ {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
+ {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
+ {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
+ {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
+ {file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
+ {file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
+ {file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
+ {file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
]
[package.dependencies]
@@ -522,13 +511,13 @@ toml = ["tomli"]
[[package]]
name = "datasets"
-version = "2.19.0"
+version = "2.19.1"
description = "HuggingFace community-driven open-source library of datasets"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "datasets-2.19.0-py3-none-any.whl", hash = "sha256:f57c5316e123d4721b970c68c1cb856505f289cda58f5557ffe745b49c011a8e"},
- {file = "datasets-2.19.0.tar.gz", hash = "sha256:0b47e08cc7af2c6800a42cadc4657b22a0afc7197786c8986d703c08d90886a6"},
+ {file = "datasets-2.19.1-py3-none-any.whl", hash = "sha256:f7a78d15896f45004ccac1c298f3c7121f92f91f6f2bfbd4e4f210f827e6e411"},
+ {file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"},
]
[package.dependencies]
@@ -778,16 +767,6 @@ files = [
[package.dependencies]
six = ">=1.4.0"
-[[package]]
-name = "egl-probe"
-version = "1.0.2"
-description = ""
-optional = false
-python-versions = "*"
-files = [
- {file = "egl_probe-1.0.2.tar.gz", hash = "sha256:29bdca7b08da1e060cfb42cd46af8300a7ac4f3b1b2eeb16e545ea16d9a5ac93"},
-]
-
[[package]]
name = "einops"
version = "0.8.0"
@@ -1048,127 +1027,69 @@ files = [
[package.extras]
preview = ["glfw-preview"]
-[[package]]
-name = "grpcio"
-version = "1.63.0"
-description = "HTTP/2-based RPC framework"
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"},
- {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"},
- {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"},
- {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"},
- {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"},
- {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"},
- {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"},
- {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"},
- {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"},
- {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"},
- {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"},
- {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"},
- {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"},
- {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"},
- {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"},
- {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"},
- {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"},
- {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"},
- {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"},
- {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"},
- {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"},
- {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"},
- {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"},
- {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"},
- {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"},
- {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"},
- {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"},
- {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"},
- {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"},
- {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"},
- {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"},
- {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"},
- {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"},
- {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"},
- {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"},
- {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"},
- {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"},
- {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"},
- {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"},
- {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"},
- {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"},
- {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"},
- {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"},
- {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"},
- {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"},
- {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"},
-]
-
-[package.extras]
-protobuf = ["grpcio-tools (>=1.63.0)"]
-
[[package]]
name = "gym-aloha"
-version = "0.1.0"
+version = "0.1.1"
description = "A gym environment for ALOHA"
optional = true
python-versions = "<4.0,>=3.10"
files = [
- {file = "gym_aloha-0.1.0-py3-none-any.whl", hash = "sha256:62e36eeb09284422cbb7baca0292c6f65e38ec8774bf9b0bf7159ad5990cf29a"},
- {file = "gym_aloha-0.1.0.tar.gz", hash = "sha256:bab332f469ba5ffe655fc3e9647aead05d2cb3b950dfb1f299b9539b3857ad7e"},
+ {file = "gym_aloha-0.1.1-py3-none-any.whl", hash = "sha256:2698037246dbb106828f0bc229b61007b0a21d5967c72cc373f7bc1083203584"},
+ {file = "gym_aloha-0.1.1.tar.gz", hash = "sha256:614ae1cf116323e7b5ae2f0e9bd282c4f052aee15e839e5587ddce45995359bc"},
]
[package.dependencies]
-dm-control = "1.0.14"
-gymnasium = ">=0.29.1,<0.30.0"
-imageio = {version = ">=2.34.0,<3.0.0", extras = ["ffmpeg"]}
+dm-control = ">=1.0.14"
+gymnasium = ">=0.29.1"
+imageio = {version = ">=2.34.0", extras = ["ffmpeg"]}
mujoco = ">=2.3.7,<3.0.0"
[package.extras]
-dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"]
-test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"]
+dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
+test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
[[package]]
name = "gym-pusht"
-version = "0.1.0"
+version = "0.1.3"
description = "A gymnasium environment for PushT."
optional = true
python-versions = "<4.0,>=3.10"
files = [
- {file = "gym_pusht-0.1.0-py3-none-any.whl", hash = "sha256:ad0841f3b03741eb08985b703eab4931653a44fd255c26e2c82a7f71cc4dc210"},
- {file = "gym_pusht-0.1.0.tar.gz", hash = "sha256:82b26ea753f28f0277c06839c68970a4ea872fae4710402349f872ae45de83f4"},
+ {file = "gym_pusht-0.1.3-py3-none-any.whl", hash = "sha256:feeb02493a03d1aacc45d43d6397962c50ed779ab7e4019d73af11d2f0b3831b"},
+ {file = "gym_pusht-0.1.3.tar.gz", hash = "sha256:c8e9a5256035ba49841ebbc7c32a06c4fa2daa52f5fad80da941b607c4553e28"},
]
[package.dependencies]
-gymnasium = ">=0.29.1,<0.30.0"
-opencv-python = ">=4.9.0.80,<5.0.0.0"
-pygame = ">=2.5.2,<3.0.0"
-pymunk = ">=6.6.0,<7.0.0"
-scikit-image = ">=0.22.0,<0.23.0"
-shapely = ">=2.0.3,<3.0.0"
+gymnasium = ">=0.29.1"
+opencv-python = ">=4.9.0"
+pygame = ">=2.5.2"
+pymunk = ">=6.6.0"
+scikit-image = ">=0.22.0"
+shapely = ">=2.0.3"
[package.extras]
-dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"]
-test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"]
+dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
+test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
[[package]]
name = "gym-xarm"
-version = "0.1.0"
+version = "0.1.1"
description = "A gym environment for xArm"
optional = true
python-versions = "<4.0,>=3.10"
files = [
- {file = "gym_xarm-0.1.0-py3-none-any.whl", hash = "sha256:d10ac19a59d302201a9b8bd913530211b1058467b787ad91a657907e40cdbc13"},
- {file = "gym_xarm-0.1.0.tar.gz", hash = "sha256:fc05f9d02af1f0205275311669dc191ce431be484e221a96401eb544764eb986"},
+ {file = "gym_xarm-0.1.1-py3-none-any.whl", hash = "sha256:3bd7e3c1c5521ba80a56536f01a5e11321580704d72160355ce47a828a8808ad"},
+ {file = "gym_xarm-0.1.1.tar.gz", hash = "sha256:e455524561b02d06b92a4f7d524f448d84a7484d9a2dbc78600e3c66240e0fb7"},
]
[package.dependencies]
-gymnasium = ">=0.29.1,<0.30.0"
-gymnasium-robotics = ">=1.2.4,<2.0.0"
+gymnasium = ">=0.29.1"
+gymnasium-robotics = ">=1.2.4"
mujoco = ">=2.3.7,<3.0.0"
[package.extras]
-dev = ["debugpy (>=1.8.1,<2.0.0)", "pre-commit (>=3.7.0,<4.0.0)"]
-test = ["pytest (>=8.1.0,<9.0.0)", "pytest-cov (>=5.0.0,<6.0.0)"]
+dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
+test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
[[package]]
name = "gymnasium"
@@ -1258,13 +1179,13 @@ numpy = ">=1.17.3"
[[package]]
name = "huggingface-hub"
-version = "0.21.4"
+version = "0.23.0"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
- {file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"},
+ {file = "huggingface_hub-0.23.0-py3-none-any.whl", hash = "sha256:075c30d48ee7db2bba779190dc526d2c11d422aed6f9044c5e2fdc2c432fdb91"},
+ {file = "huggingface_hub-0.23.0.tar.gz", hash = "sha256:7126dedd10a4c6fac796ced4d87a8cf004efc722a5125c2c09299017fa366fa9"},
]
[package.dependencies]
@@ -1277,15 +1198,16 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"
[package.extras]
-all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"]
-dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
hf-transfer = ["hf-transfer (>=0.1.4)"]
-inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"]
-quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"]
+inference = ["aiohttp", "minijinja (>=1.0)"]
+quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
-testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+tensorflow-testing = ["keras (<3.0)", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
torch = ["safetensors", "torch"]
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
@@ -1471,13 +1393,13 @@ files = [
[[package]]
name = "jinja2"
-version = "3.1.3"
+version = "3.1.4"
description = "A very fast and expressive template engine."
optional = false
python-versions = ">=3.7"
files = [
- {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"},
- {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"},
+ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
+ {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"},
]
[package.dependencies]
@@ -1678,7 +1600,6 @@ files = [
{file = "lxml-5.2.1-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:9e2addd2d1866fe112bc6f80117bcc6bc25191c5ed1bfbcf9f1386a884252ae8"},
{file = "lxml-5.2.1-cp37-cp37m-win32.whl", hash = "sha256:f51969bac61441fd31f028d7b3b45962f3ecebf691a510495e5d2cd8c8092dbd"},
{file = "lxml-5.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b0b58fbfa1bf7367dde8a557994e3b1637294be6cf2169810375caf8571a085c"},
- {file = "lxml-5.2.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e183c6e3298a2ed5af9d7a356ea823bccaab4ec2349dc9ed83999fd289d14d5"},
{file = "lxml-5.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:804f74efe22b6a227306dd890eecc4f8c59ff25ca35f1f14e7482bbce96ef10b"},
{file = "lxml-5.2.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08802f0c56ed150cc6885ae0788a321b73505d2263ee56dad84d200cab11c07a"},
{file = "lxml-5.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8c09ed18ecb4ebf23e02b8e7a22a05d6411911e6fabef3a36e4f371f4f2585"},
@@ -1750,21 +1671,6 @@ html5 = ["html5lib"]
htmlsoup = ["BeautifulSoup4"]
source = ["Cython (>=3.0.10)"]
-[[package]]
-name = "markdown"
-version = "3.6"
-description = "Python implementation of John Gruber's Markdown."
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"},
- {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"},
-]
-
-[package.extras]
-docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
-testing = ["coverage", "pyyaml"]
-
[[package]]
name = "markupsafe"
version = "2.1.5"
@@ -2407,6 +2313,7 @@ optional = false
python-versions = ">=3.9"
files = [
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
+ {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
@@ -2427,6 +2334,7 @@ files = [
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
+ {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
@@ -2585,7 +2493,7 @@ xmp = ["defusedxml"]
name = "platformdirs"
version = "4.2.1"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
-optional = true
+optional = false
python-versions = ">=3.8"
files = [
{file = "platformdirs-4.2.1-py3-none-any.whl", hash = "sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1"},
@@ -2614,13 +2522,13 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pre-commit"
-version = "3.7.0"
+version = "3.7.1"
description = "A framework for managing and maintaining multi-language pre-commit hooks."
optional = true
python-versions = ">=3.9"
files = [
- {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"},
- {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"},
+ {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"},
+ {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"},
]
[package.dependencies]
@@ -2868,74 +2776,74 @@ files = [
[[package]]
name = "pymunk"
-version = "6.6.0"
-description = "Pymunk is a easy-to-use pythonic 2d physics library"
+version = "6.8.0"
+description = "Pymunk is a easy-to-use pythonic 2D physics library"
optional = false
python-versions = ">=3.7"
files = [
- {file = "pymunk-6.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6da50dd97683337a290110d594fad07a75153d2d837b570ef972478d739c33f8"},
- {file = "pymunk-6.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bcd7d16a2b4d51d45d6780a701f65c8d5b36fdf545c3f4738910da41e2a9c4ee"},
- {file = "pymunk-6.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32c91a783b645267518588515acdc3ff315135297eef39386d488c4ff2a7c139"},
- {file = "pymunk-6.6.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74694f92f46fe54e2c033b598b2c38185f456711888955aa3f67003692a3ef91"},
- {file = "pymunk-6.6.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fe011afb3f7594a679ba35dc7a44e12c8c8aacb55e58d54f14bfe8b82959695c"},
- {file = "pymunk-6.6.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:60e5cc6e33f7e880418f75a7d6b5ac3eed47396bbe7c68ca47c389de3b5d1d3a"},
- {file = "pymunk-6.6.0-cp310-cp310-win32.whl", hash = "sha256:10518074e33d4fe723bce795f705ad3e850ecec9987559ec3fa072a6539c47ad"},
- {file = "pymunk-6.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b163b28f9500df1bb5e123e2dba2d1f255e63be6ca098544936a93c05022a43"},
- {file = "pymunk-6.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8322594fc68858bfc0142f2f7a100cfb4edb85678a75983ce2fc58ed763afb96"},
- {file = "pymunk-6.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c1d0be60b781d1b8bb11303b25936d01cdef7ccfcc3a68b0c2fd689f63ac11c"},
- {file = "pymunk-6.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9462200c47f3eb344373077dc01384cb16355a982ce0e33571201f3b7ee44487"},
- {file = "pymunk-6.6.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ede46cc44432b1316a402129fc225743f7e9f502d0d055790eab877627ddfd98"},
- {file = "pymunk-6.6.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3582cd67d6ac16f122d2b7100e0b00d9b55f97a0a7e21336df885166e2bffdc3"},
- {file = "pymunk-6.6.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e4247ede965df39d2fd7ae25e9360762cce61f4d39b95af91d29c1c556c80777"},
- {file = "pymunk-6.6.0-cp311-cp311-win32.whl", hash = "sha256:a77f9bb634ab216ac8991f73aa68b4dadfd6690e8cb17627a6646dc8fecd6126"},
- {file = "pymunk-6.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:2f579e8c5498b3e8c0686841f1f5e3adf1bdd32b339ee36001ebae19bbafc008"},
- {file = "pymunk-6.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ad6ca584a9ea1d6a1536ae158350d73dbbdc637f302a86019b7fb299120439c4"},
- {file = "pymunk-6.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b86be4ecfb86d4af26c3dd2e390884305c3b8604e5df8550fbb2968d3ac78411"},
- {file = "pymunk-6.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68006cfb71351b6f23a81f541a2eca56596e69977e051e46cfe93a5ffdc410ef"},
- {file = "pymunk-6.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:361d2fa43e65aa5e47dcb50e6b058b3814e19cbdb5bf062d2da78c2b3bdba192"},
- {file = "pymunk-6.6.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3c35975f4172b024e0bb1be6f57f1048dcb469a8cf257c30123d11a9fe57e2a"},
- {file = "pymunk-6.6.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:35a57546294656b5bb989e08426a4926e26a17726aef35daf34c2703ee54c0e9"},
- {file = "pymunk-6.6.0-cp312-cp312-win32.whl", hash = "sha256:a68480440b60bf5acf3a7a8db1eb571e13ed425d5b693a20020f2efa9cc09592"},
- {file = "pymunk-6.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:f7ed11a1e2a306e4213d88a1879ae0fb7c2c983a890fa1b35ed26b9392213c02"},
- {file = "pymunk-6.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:49961e339571d14afa9ebc815190ebfdda69e6ffd433536451bb07d6bc55e430"},
- {file = "pymunk-6.6.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70d9d5edcb2e90eeea0afb322c82d75a02e6bb77a9ff08b86daa2245a2c2a4ef"},
- {file = "pymunk-6.6.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:622746251dda14507d3655b64c93a4509125c0a651265c473945f227ba5763ec"},
- {file = "pymunk-6.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:35277485eb69cc5dda3f15b139723c77d69b9271f9fedf4264d08e8afdea67d0"},
- {file = "pymunk-6.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ceadcf03988c51697a3357d6dd3c96dd60e48b993734346edf8955fcd3770466"},
- {file = "pymunk-6.6.0-cp37-cp37m-win32.whl", hash = "sha256:0d8e0e79135e86b6e0e686fd287f297488e728cb8276fc713cb33fdd7ce4f5f2"},
- {file = "pymunk-6.6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:259a371150a9e264851d0a9caa85b5a19ba661f364da630a231a3eb326d49ca1"},
- {file = "pymunk-6.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e97ad1ce7fa3e9ea15622d1e0c45e2757f02e1c947a354888c2014799575c100"},
- {file = "pymunk-6.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4cd70ffc259b8069eabb54ed5c7cbc39d0f5158610791c14ad0437f6cb6d18d"},
- {file = "pymunk-6.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cf85e95774a89b2adf084c0129d62f69eaa23b97b800892ddcfa7862b931bbb"},
- {file = "pymunk-6.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8d60b9fcde952d6e25c740a1ced5612ace59fa85e578986f7f053a538a681ed"},
- {file = "pymunk-6.6.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5652c423ea2769b1d44e33fd2b19f2a6f7f4a34acacec9a86b63c780ac611552"},
- {file = "pymunk-6.6.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:939e46d9021cb5bc6ac4dbecafe89245af2b8325787869983b0a99181e37fd39"},
- {file = "pymunk-6.6.0-cp38-cp38-win32.whl", hash = "sha256:7785f5ac0597be5693dd2da819233297984d324b6470bf31b76c71399f25a18e"},
- {file = "pymunk-6.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:26a0834207785878ba2bb244ab5616d9b6e09d01c2f19641f10247ca22d3c10e"},
- {file = "pymunk-6.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ed23f05a65687750cba4d6cde045147d28eee84e44cd33829b79601dc655adf3"},
- {file = "pymunk-6.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b2c23f2f182f91944c4ba5cfd6f652e873e6e8b113506c3eca255df5e6c79b6e"},
- {file = "pymunk-6.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c51b6de5869dcb103467d8ac75f62a1a9f43faa18bf12e37e89247b2d5554a61"},
- {file = "pymunk-6.6.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b44c4420b43cfdfedd2278e3beb60970a9a9564f1272c7cc74090931268ab43"},
- {file = "pymunk-6.6.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4b8d9c14275fd4853ae863e38bec8a7ae4c7aef4417550ff74fc9f68f120fa00"},
- {file = "pymunk-6.6.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:31631a91bf29dde9c4a5f1283056cb91d451fe352f35a440c5cb668b0de19ad5"},
- {file = "pymunk-6.6.0-cp39-cp39-win32.whl", hash = "sha256:832d83570d0781e2bcba555b0974e9a5f9ee592079dfd3b183a493cf0ceaac7f"},
- {file = "pymunk-6.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:88625cca15c90dc8c0c1b55113f0ff19a8e6601ac0981804d317660c0afde9e2"},
- {file = "pymunk-6.6.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8e27a8c7b762d43e91f18c320ad849c113dead500184d151aa14bd11a62c2c47"},
- {file = "pymunk-6.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aad898ca45546f084b0d88f73c771e3de0d19acc65f1171a9dbdba171945a915"},
- {file = "pymunk-6.6.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:45f537c79e817330753e6ed220b3ff46b5b983266d5b85ce7c1381a77b33d1f3"},
- {file = "pymunk-6.6.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:609341ff1329e59ee7a67b622973064c213111e87916981bc45838f38981ba47"},
- {file = "pymunk-6.6.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:345b99d19cb848359fbefcaba54a5f1bcc8dd05b084563d693ca4d0622aa1079"},
- {file = "pymunk-6.6.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f33c418b758e06960fa28e0434c14818c0d9755f431045db05cc93e646df9b22"},
- {file = "pymunk-6.6.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:59991310cb1a6f201878e9519cbb36ff746f825c9fac49fa76cf8c85b64bf7ad"},
- {file = "pymunk-6.6.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c7513caf1add221cfa1228c12e14e0997a7212e583a59f517b68e72b1f02e08f"},
- {file = "pymunk-6.6.0.tar.gz", hash = "sha256:89be7b6ba237e313c440edfb99612de59bf119e43976d5c76802907cb7a3911c"},
+ {file = "pymunk-6.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:96b363241569e616bea1d62346552e3503f2b58d6715b5bda20a13ba2522cea2"},
+ {file = "pymunk-6.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:91aab5c0640cddaf100eaaf6df8578aa3b97b70c5e5e0c1d26f6826eefec8e96"},
+ {file = "pymunk-6.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78973d4ef0fa715e127ca408f6063c82722b883d98cf216dddd906aa00debf4e"},
+ {file = "pymunk-6.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7944db455d81bc552fc1b2edadbf82a6b91b11ee75193552ef629d0b8106975b"},
+ {file = "pymunk-6.8.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3ec2a1c2a3b2b0eac97ebd2b84dfd51bd98b2b753a33dce81f4f26fa9e1a8974"},
+ {file = "pymunk-6.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3134ba935921e0e888d9013258f1de111bc2da05a02ec2a7d520a8c580f28fba"},
+ {file = "pymunk-6.8.0-cp310-cp310-win32.whl", hash = "sha256:a3222cb84481408faf6236f4cea6a17045881e5780a4dccc170344a7a4ad160a"},
+ {file = "pymunk-6.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:7f13bffe06634f23e0453b0a5e388142cdbaa06f54a243aae98e2b9c2793ebb0"},
+ {file = "pymunk-6.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b0db72af65205c80d15403e6784bdd7c559e620411394449356dc869b90ade1c"},
+ {file = "pymunk-6.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d6e5cc30d15eba4bc33e00b9fb8c366255dac47fe486f5276f6334e8a8c34754"},
+ {file = "pymunk-6.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3752e8486a4b8bfaa988be59a9773f0371d0cd52a7901fe7ba9caed1ea5b6129"},
+ {file = "pymunk-6.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0f20b3dfc56aee57cc910ce65505286e184e045c9169bd8d5eff50974b6a1e97"},
+ {file = "pymunk-6.8.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6d807bba8fd7bae35754c29c2cb8809f0cf7817541c8cb4d134872e102899724"},
+ {file = "pymunk-6.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ac39afd90800061b353cbeb71171109ef7411cc88f5607b10447b2906e4cde97"},
+ {file = "pymunk-6.8.0-cp311-cp311-win32.whl", hash = "sha256:4de7683f832c694b82dbe7c20469765663f06ee82f8b711bd72731684b452889"},
+ {file = "pymunk-6.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:6620fc590290c37e58c8495fb74e5eb433f24b7c7d46c8a7b4b54c56ca9990ab"},
+ {file = "pymunk-6.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6eee05919e1e8f63b64a75362e82691365918b0ea33af11d2b3aab1d81402a3d"},
+ {file = "pymunk-6.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b99fa28a8fa5242560a7f2d59604b9e55eed865d8308dd5f93af97ad2605f84"},
+ {file = "pymunk-6.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9be0614a4c0eaaff209656df464e9eb5653dc171a15230dd4d307a3f3564e6"},
+ {file = "pymunk-6.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c22598b75ef99dd70fb2c6b8989c55ab00fac379555ebf68cfe7adfa65fd94e"},
+ {file = "pymunk-6.8.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:533bb7555df1c904f1056ac9149e59ab039ee195fa22c69500843ef7e3f57062"},
+ {file = "pymunk-6.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9de831180f7650c47fcfcadadf5d295e72065c8500a4c9a78f6d37229c2ca58b"},
+ {file = "pymunk-6.8.0-cp312-cp312-win32.whl", hash = "sha256:8bbc9189c71a6c51825f8246e94c6642823ef42a4b3ed9c2afa7f8ec48425929"},
+ {file = "pymunk-6.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:e3b61a162153dfdd0ebbab82eddb417da812085b3587ebf92a225d33df8e044d"},
+ {file = "pymunk-6.8.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8a4fa8e3672a3b49850adf71b0eabacabb73db0514cbece0649bc77e1a124924"},
+ {file = "pymunk-6.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ada172ee87296fdfdf5ac0d88c43502b482697185ce9b6d03d0f0d4b5b11532"},
+ {file = "pymunk-6.8.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c0c9d881ca7d9a9375ce248e90f24efb298d41e4abf8e16f5c7e78c66983c34"},
+ {file = "pymunk-6.8.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e5a6e2a7ff652b2977e24b4ed2d33fc7d628bd4e54ddeb488515b1475f715d91"},
+ {file = "pymunk-6.8.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:12f0af3417a95c5ab97207a5e54fbc91a54c801b4560283532f5582401a1f36e"},
+ {file = "pymunk-6.8.0-cp37-cp37m-win32.whl", hash = "sha256:382aaa71d7615ded7cfc644a091391cf0fd3ecf7bc556e0145d0f6982c942ee7"},
+ {file = "pymunk-6.8.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c40361a2b017107303568ace3e506d87ab46d67d31484f656ba7792901d20abd"},
+ {file = "pymunk-6.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54167f0f9b65a49f35fbde56746ce7020b74b39b86ef8cec8804ef9422d258c9"},
+ {file = "pymunk-6.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3e32c520ba2729c97fd5031cc0caa99de20a9b6dda85b8468cf769afa7a4a85c"},
+ {file = "pymunk-6.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93c86daf68fb0785722cbea3fc36afcf9830af156de9ed99cbf2b6d6475240ab"},
+ {file = "pymunk-6.8.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:269576ad7d4241c75611df0f0e3ff0b19de436f4facabb21185e579a573c91d0"},
+ {file = "pymunk-6.8.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a4c90c82e6dfd35930ad779287859c3a867518065fce97fee6eeaf81a1754ea6"},
+ {file = "pymunk-6.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06b5ce0a26a26b3490107c962632f4dd53137df14e11b0e55c9816005509dea1"},
+ {file = "pymunk-6.8.0-cp38-cp38-win32.whl", hash = "sha256:ff3b4d086f47f4fee9980977ec4f1121909b5456ed05fcad3c0f2f6e224e1fef"},
+ {file = "pymunk-6.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d8f7124ab9e5c69ab698d3836dc0d84e1a31d47b5e7ce6477cf5205d701887"},
+ {file = "pymunk-6.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:47803f8229b95e9ae56cada9566e1b92b2789affe2229ed623d3a871fd307982"},
+ {file = "pymunk-6.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:194cf34146b7393ebdd2e37cb50d5579e737baf378f900a50ff477c909a163c9"},
+ {file = "pymunk-6.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af8d9d0a26dc4d504ac881407ac9d7689b0c89bf8c0535717c15583773eb965"},
+ {file = "pymunk-6.8.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79c50449c12120d05fdf515e4c937220580f945ed1eda8c374734a3418fc18e6"},
+ {file = "pymunk-6.8.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:45197c7fcebff6bc3d4b7f3ccef150a6c4c43f71860e03503e851c8ecc0af861"},
+ {file = "pymunk-6.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:183ecbbafffe8a071ffb0efb6a5daa21f05d2b9a36c0538b47fbd15b6e6fa6e5"},
+ {file = "pymunk-6.8.0-cp39-cp39-win32.whl", hash = "sha256:2453eff73e474c1f282088e70a5dfd970ebc9f565c1b39b1da69df1b43dee47f"},
+ {file = "pymunk-6.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:769aae66d3918fa7d9ac33fca4e693a53aba1ed801050450c1a42b4c8ecc7250"},
+ {file = "pymunk-6.8.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:22153b93074e8f397e245aad2811e50ccc94502616341a1420c2a3a7332c1db0"},
+ {file = "pymunk-6.8.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b6d5d62b1f8ad3b8626be64817bed709edbbd03b578e33ae3e39ab7f9301055"},
+ {file = "pymunk-6.8.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50d9ff9e7caa3a7f432f4c6da4d1acd0dd1208ca22fc4cea3d48d845ac4111b3"},
+ {file = "pymunk-6.8.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:b0688613c641c5a018091ba9285a38cb6e53a64daa9ec3bc80006ea6c4531a32"},
+ {file = "pymunk-6.8.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:42cf265d55dd90ab3441c9e0a1596094372b063713760d2a5489321d1b9beddb"},
+ {file = "pymunk-6.8.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f199f1f291f7ad8ec3c00d529c6207bb7a67497b6ecd79ffb27df1aabe973d7"},
+ {file = "pymunk-6.8.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79fe51090090f4dd4e41603bbc46926be4bae3c3f554664b907fc3fda65019f8"},
+ {file = "pymunk-6.8.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:03100f749d276742244d560e12ae125cbcf13606d24fd60455b52d5b9a9f7a17"},
+ {file = "pymunk-6.8.0.tar.gz", hash = "sha256:882929eac3cc5107bec13da7bbe9b6a3868df87ecc373475d0d1aae82d2f5dda"},
]
[package.dependencies]
cffi = ">=1.15.0"
[package.extras]
-dev = ["aafigure", "matplotlib", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"]
+dev = ["aafigure", "matplotlib", "numpy", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"]
[[package]]
name = "pyopengl"
@@ -3101,90 +3009,90 @@ files = [
[[package]]
name = "regex"
-version = "2024.4.28"
+version = "2024.5.10"
description = "Alternative regular expression module, to replace re."
optional = false
python-versions = ">=3.8"
files = [
- {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd196d056b40af073d95a2879678585f0b74ad35190fac04ca67954c582c6b61"},
- {file = "regex-2024.4.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8bb381f777351bd534462f63e1c6afb10a7caa9fa2a421ae22c26e796fe31b1f"},
- {file = "regex-2024.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:47af45b6153522733aa6e92543938e97a70ce0900649ba626cf5aad290b737b6"},
- {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99d6a550425cc51c656331af0e2b1651e90eaaa23fb4acde577cf15068e2e20f"},
- {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf29304a8011feb58913c382902fde3395957a47645bf848eea695839aa101b7"},
- {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92da587eee39a52c91aebea8b850e4e4f095fe5928d415cb7ed656b3460ae79a"},
- {file = "regex-2024.4.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6277d426e2f31bdbacb377d17a7475e32b2d7d1f02faaecc48d8e370c6a3ff31"},
- {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28e1f28d07220c0f3da0e8fcd5a115bbb53f8b55cecf9bec0c946eb9a059a94c"},
- {file = "regex-2024.4.28-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aaa179975a64790c1f2701ac562b5eeb733946eeb036b5bcca05c8d928a62f10"},
- {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6f435946b7bf7a1b438b4e6b149b947c837cb23c704e780c19ba3e6855dbbdd3"},
- {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:19d6c11bf35a6ad077eb23852827f91c804eeb71ecb85db4ee1386825b9dc4db"},
- {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:fdae0120cddc839eb8e3c15faa8ad541cc6d906d3eb24d82fb041cfe2807bc1e"},
- {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e672cf9caaf669053121f1766d659a8813bd547edef6e009205378faf45c67b8"},
- {file = "regex-2024.4.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f57515750d07e14743db55d59759893fdb21d2668f39e549a7d6cad5d70f9fea"},
- {file = "regex-2024.4.28-cp310-cp310-win32.whl", hash = "sha256:a1409c4eccb6981c7baabc8888d3550df518add6e06fe74fa1d9312c1838652d"},
- {file = "regex-2024.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:1f687a28640f763f23f8a9801fe9e1b37338bb1ca5d564ddd41619458f1f22d1"},
- {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:84077821c85f222362b72fdc44f7a3a13587a013a45cf14534df1cbbdc9a6796"},
- {file = "regex-2024.4.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b45d4503de8f4f3dc02f1d28a9b039e5504a02cc18906cfe744c11def942e9eb"},
- {file = "regex-2024.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:457c2cd5a646dd4ed536c92b535d73548fb8e216ebee602aa9f48e068fc393f3"},
- {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b51739ddfd013c6f657b55a508de8b9ea78b56d22b236052c3a85a675102dc6"},
- {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:459226445c7d7454981c4c0ce0ad1a72e1e751c3e417f305722bbcee6697e06a"},
- {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:670fa596984b08a4a769491cbdf22350431970d0112e03d7e4eeaecaafcd0fec"},
- {file = "regex-2024.4.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00f4fe11c8a521b173e6324d862ee7ee3412bf7107570c9b564fe1119b56fb"},
- {file = "regex-2024.4.28-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36f392dc7763fe7924575475736bddf9ab9f7a66b920932d0ea50c2ded2f5636"},
- {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:23a412b7b1a7063f81a742463f38821097b6a37ce1e5b89dd8e871d14dbfd86b"},
- {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f1d6e4b7b2ae3a6a9df53efbf199e4bfcff0959dbdb5fd9ced34d4407348e39a"},
- {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:499334ad139557de97cbc4347ee921c0e2b5e9c0f009859e74f3f77918339257"},
- {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0940038bec2fe9e26b203d636c44d31dd8766abc1fe66262da6484bd82461ccf"},
- {file = "regex-2024.4.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:66372c2a01782c5fe8e04bff4a2a0121a9897e19223d9eab30c54c50b2ebeb7f"},
- {file = "regex-2024.4.28-cp311-cp311-win32.whl", hash = "sha256:c77d10ec3c1cf328b2f501ca32583625987ea0f23a0c2a49b37a39ee5c4c4630"},
- {file = "regex-2024.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:fc0916c4295c64d6890a46e02d4482bb5ccf33bf1a824c0eaa9e83b148291f90"},
- {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:08a1749f04fee2811c7617fdd46d2e46d09106fa8f475c884b65c01326eb15c5"},
- {file = "regex-2024.4.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b8eb28995771c087a73338f695a08c9abfdf723d185e57b97f6175c5051ff1ae"},
- {file = "regex-2024.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd7ef715ccb8040954d44cfeff17e6b8e9f79c8019daae2fd30a8806ef5435c0"},
- {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb0315a2b26fde4005a7c401707c5352df274460f2f85b209cf6024271373013"},
- {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fc053228a6bd3a17a9b0a3f15c3ab3cf95727b00557e92e1cfe094b88cc662"},
- {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7fe9739a686dc44733d52d6e4f7b9c77b285e49edf8570754b322bca6b85b4cc"},
- {file = "regex-2024.4.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74fcf77d979364f9b69fcf8200849ca29a374973dc193a7317698aa37d8b01c"},
- {file = "regex-2024.4.28-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:965fd0cf4694d76f6564896b422724ec7b959ef927a7cb187fc6b3f4e4f59833"},
- {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2fef0b38c34ae675fcbb1b5db760d40c3fc3612cfa186e9e50df5782cac02bcd"},
- {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bc365ce25f6c7c5ed70e4bc674f9137f52b7dd6a125037f9132a7be52b8a252f"},
- {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ac69b394764bb857429b031d29d9604842bc4cbfd964d764b1af1868eeebc4f0"},
- {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:144a1fc54765f5c5c36d6d4b073299832aa1ec6a746a6452c3ee7b46b3d3b11d"},
- {file = "regex-2024.4.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2630ca4e152c221072fd4a56d4622b5ada876f668ecd24d5ab62544ae6793ed6"},
- {file = "regex-2024.4.28-cp312-cp312-win32.whl", hash = "sha256:7f3502f03b4da52bbe8ba962621daa846f38489cae5c4a7b5d738f15f6443d17"},
- {file = "regex-2024.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:0dd3f69098511e71880fb00f5815db9ed0ef62c05775395968299cb400aeab82"},
- {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:374f690e1dd0dbdcddea4a5c9bdd97632cf656c69113f7cd6a361f2a67221cb6"},
- {file = "regex-2024.4.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:25f87ae6b96374db20f180eab083aafe419b194e96e4f282c40191e71980c666"},
- {file = "regex-2024.4.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5dbc1bcc7413eebe5f18196e22804a3be1bfdfc7e2afd415e12c068624d48247"},
- {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f85151ec5a232335f1be022b09fbbe459042ea1951d8a48fef251223fc67eee1"},
- {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:57ba112e5530530fd175ed550373eb263db4ca98b5f00694d73b18b9a02e7185"},
- {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:224803b74aab56aa7be313f92a8d9911dcade37e5f167db62a738d0c85fdac4b"},
- {file = "regex-2024.4.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a54a047b607fd2d2d52a05e6ad294602f1e0dec2291152b745870afc47c1397"},
- {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a2a512d623f1f2d01d881513af9fc6a7c46e5cfffb7dc50c38ce959f9246c94"},
- {file = "regex-2024.4.28-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c06bf3f38f0707592898428636cbb75d0a846651b053a1cf748763e3063a6925"},
- {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1031a5e7b048ee371ab3653aad3030ecfad6ee9ecdc85f0242c57751a05b0ac4"},
- {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d7a353ebfa7154c871a35caca7bfd8f9e18666829a1dc187115b80e35a29393e"},
- {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:7e76b9cfbf5ced1aca15a0e5b6f229344d9b3123439ffce552b11faab0114a02"},
- {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5ce479ecc068bc2a74cb98dd8dba99e070d1b2f4a8371a7dfe631f85db70fe6e"},
- {file = "regex-2024.4.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7d77b6f63f806578c604dca209280e4c54f0fa9a8128bb8d2cc5fb6f99da4150"},
- {file = "regex-2024.4.28-cp38-cp38-win32.whl", hash = "sha256:d84308f097d7a513359757c69707ad339da799e53b7393819ec2ea36bc4beb58"},
- {file = "regex-2024.4.28-cp38-cp38-win_amd64.whl", hash = "sha256:2cc1b87bba1dd1a898e664a31012725e48af826bf3971e786c53e32e02adae6c"},
- {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7413167c507a768eafb5424413c5b2f515c606be5bb4ef8c5dee43925aa5718b"},
- {file = "regex-2024.4.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:108e2dcf0b53a7c4ab8986842a8edcb8ab2e59919a74ff51c296772e8e74d0ae"},
- {file = "regex-2024.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f1c5742c31ba7d72f2dedf7968998730664b45e38827637e0f04a2ac7de2f5f1"},
- {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc6148228c9ae25ce403eade13a0961de1cb016bdb35c6eafd8e7b87ad028b1"},
- {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7d893c8cf0e2429b823ef1a1d360a25950ed11f0e2a9df2b5198821832e1947"},
- {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4290035b169578ffbbfa50d904d26bec16a94526071ebec3dadbebf67a26b25e"},
- {file = "regex-2024.4.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a22ae1cfd82e4ffa2066eb3390777dc79468f866f0625261a93e44cdf6482b"},
- {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd24fd140b69f0b0bcc9165c397e9b2e89ecbeda83303abf2a072609f60239e2"},
- {file = "regex-2024.4.28-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:39fb166d2196413bead229cd64a2ffd6ec78ebab83fff7d2701103cf9f4dfd26"},
- {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9301cc6db4d83d2c0719f7fcda37229691745168bf6ae849bea2e85fc769175d"},
- {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c3d389e8d76a49923683123730c33e9553063d9041658f23897f0b396b2386f"},
- {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:99ef6289b62042500d581170d06e17f5353b111a15aa6b25b05b91c6886df8fc"},
- {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b91d529b47798c016d4b4c1d06cc826ac40d196da54f0de3c519f5a297c5076a"},
- {file = "regex-2024.4.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:43548ad74ea50456e1c68d3c67fff3de64c6edb85bcd511d1136f9b5376fc9d1"},
- {file = "regex-2024.4.28-cp39-cp39-win32.whl", hash = "sha256:05d9b6578a22db7dedb4df81451f360395828b04f4513980b6bd7a1412c679cc"},
- {file = "regex-2024.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:3986217ec830c2109875be740531feb8ddafe0dfa49767cdcd072ed7e8927962"},
- {file = "regex-2024.4.28.tar.gz", hash = "sha256:83ab366777ea45d58f72593adf35d36ca911ea8bd838483c1823b883a121b0e4"},
+ {file = "regex-2024.5.10-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:eda3dd46df535da787ffb9036b5140f941ecb91701717df91c9daf64cabef953"},
+ {file = "regex-2024.5.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1d5bd666466c8f00a06886ce1397ba8b12371c1f1c6d1bef11013e9e0a1464a8"},
+ {file = "regex-2024.5.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:32e5f3b8e32918bfbdd12eca62e49ab3031125c454b507127ad6ecbd86e62fca"},
+ {file = "regex-2024.5.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:534efd2653ebc4f26fc0e47234e53bf0cb4715bb61f98c64d2774a278b58c846"},
+ {file = "regex-2024.5.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:193b7c6834a06f722f0ce1ba685efe80881de7c3de31415513862f601097648c"},
+ {file = "regex-2024.5.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:160ba087232c5c6e2a1e7ad08bd3a3f49b58c815be0504d8c8aacfb064491cd8"},
+ {file = "regex-2024.5.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951be1eae7b47660412dc4938777a975ebc41936d64e28081bf2e584b47ec246"},
+ {file = "regex-2024.5.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8a0f0ab5453e409586b11ebe91c672040bc804ca98d03a656825f7890cbdf88"},
+ {file = "regex-2024.5.10-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9e6d4d6ae1827b2f8c7200aaf7501c37cf3f3896c86a6aaf2566448397c823dd"},
+ {file = "regex-2024.5.10-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:161a206c8f3511e2f5fafc9142a2cc25d7fe9a1ec5ad9b4ad2496a7c33e1c5d2"},
+ {file = "regex-2024.5.10-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:44b3267cea873684af022822195298501568ed44d542f9a2d9bebc0212e99069"},
+ {file = "regex-2024.5.10-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:560278c9975694e1f0bc50da187abf2cdc1e4890739ea33df2bc4a85eeef143e"},
+ {file = "regex-2024.5.10-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:70364a097437dd0a90b31cd77f09f7387ad9ac60ef57590971f43b7fca3082a5"},
+ {file = "regex-2024.5.10-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:42be5de7cc8c1edac55db92d82b68dc8e683b204d6f5414c5a51997a323d7081"},
+ {file = "regex-2024.5.10-cp310-cp310-win32.whl", hash = "sha256:9a8625849387b9d558d528e263ecc9c0fbde86cfa5c2f0eef43fff480ae24d71"},
+ {file = "regex-2024.5.10-cp310-cp310-win_amd64.whl", hash = "sha256:903350bf44d7e4116b4d5898b30b15755d61dcd3161e3413a49c7db76f0bee5a"},
+ {file = "regex-2024.5.10-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bf9596cba92ce7b1fd32c7b07c6e3212c7eed0edc271757e48bfcd2b54646452"},
+ {file = "regex-2024.5.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:45cc13d398b6359a7708986386f72bd156ae781c3e83a68a6d4cee5af04b1ce9"},
+ {file = "regex-2024.5.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad45f3bccfcb00868f2871dce02a755529838d2b86163ab8a246115e80cfb7d6"},
+ {file = "regex-2024.5.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33d19f0cde6838c81acffff25c7708e4adc7dd02896c9ec25c3939b1500a1778"},
+ {file = "regex-2024.5.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0a9f89d7db5ef6bdf53e5cc8e6199a493d0f1374b3171796b464a74ebe8e508a"},
+ {file = "regex-2024.5.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c6c71cf92b09e5faa72ea2c68aa1f61c9ce11cb66fdc5069d712f4392ddfd00"},
+ {file = "regex-2024.5.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7467ad8b0eac0b28e52679e972b9b234b3de0ea5cee12eb50091d2b68145fe36"},
+ {file = "regex-2024.5.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc0db93ad039fc2fe32ccd3dd0e0e70c4f3d6e37ae83f0a487e1aba939bd2fbd"},
+ {file = "regex-2024.5.10-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fa9335674d7c819674467c7b46154196c51efbaf5f5715187fd366814ba3fa39"},
+ {file = "regex-2024.5.10-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7dda3091838206969c2b286f9832dff41e2da545b99d1cfaea9ebd8584d02708"},
+ {file = "regex-2024.5.10-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:504b5116e2bd1821efd815941edff7535e93372a098e156bb9dffde30264e798"},
+ {file = "regex-2024.5.10-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:91b53dea84415e8115506cc62e441a2b54537359c63d856d73cb1abe05af4c9a"},
+ {file = "regex-2024.5.10-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1a3903128f9e17a500618e80c68165c78c741ebb17dd1a0b44575f92c3c68b02"},
+ {file = "regex-2024.5.10-cp311-cp311-win32.whl", hash = "sha256:236cace6c1903effd647ed46ce6dd5d76d54985fc36dafc5256032886736c85d"},
+ {file = "regex-2024.5.10-cp311-cp311-win_amd64.whl", hash = "sha256:12446827f43c7881decf2c126762e11425de5eb93b3b0d8b581344c16db7047a"},
+ {file = "regex-2024.5.10-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:14905ed75c7a6edf423eb46c213ed3f4507c38115f1ed3c00f4ec9eafba50e58"},
+ {file = "regex-2024.5.10-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4fad420b14ae1970a1f322e8ae84a1d9d89375eb71e1b504060ab2d1bfe68f3c"},
+ {file = "regex-2024.5.10-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c46a76a599fcbf95f98755275c5527304cc4f1bb69919434c1e15544d7052910"},
+ {file = "regex-2024.5.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0faecb6d5779753a6066a3c7a0471a8d29fe25d9981ca9e552d6d1b8f8b6a594"},
+ {file = "regex-2024.5.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aab65121229c2ecdf4a31b793d99a6a0501225bd39b616e653c87b219ed34a49"},
+ {file = "regex-2024.5.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:50e7e96a527488334379e05755b210b7da4a60fc5d6481938c1fa053e0c92184"},
+ {file = "regex-2024.5.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba034c8db4b264ef1601eb33cd23d87c5013b8fb48b8161debe2e5d3bd9156b0"},
+ {file = "regex-2024.5.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:031219782d97550c2098d9a68ce9e9eaefe67d2d81d8ff84c8354f9c009e720c"},
+ {file = "regex-2024.5.10-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:62b5f7910b639f3c1d122d408421317c351e213ca39c964ad4121f27916631c6"},
+ {file = "regex-2024.5.10-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cd832bd9b6120d6074f39bdfbb3c80e416848b07ac72910f1c7f03131a6debc3"},
+ {file = "regex-2024.5.10-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:e91b1976358e17197157b405cab408a5f4e33310cda211c49fc6da7cffd0b2f0"},
+ {file = "regex-2024.5.10-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:571452362d552de508c37191b6abbbb660028b8b418e2d68c20779e0bc8eaaa8"},
+ {file = "regex-2024.5.10-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5253dcb0bfda7214523de58b002eb0090cb530d7c55993ce5f6d17faf953ece7"},
+ {file = "regex-2024.5.10-cp312-cp312-win32.whl", hash = "sha256:2f30a5ab8902f93930dc6f627c4dd5da2703333287081c85cace0fc6e21c25af"},
+ {file = "regex-2024.5.10-cp312-cp312-win_amd64.whl", hash = "sha256:3799e36d60a35162bb35b2246d8bb012192b7437dff807ef79c14e7352706306"},
+ {file = "regex-2024.5.10-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:bbdc5db2c98ac2bf1971ffa1410c87ca7a15800415f788971e8ba8520fc0fda9"},
+ {file = "regex-2024.5.10-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6ccdeef4584450b6f0bddd5135354908dacad95425fcb629fe36d13e48b60f32"},
+ {file = "regex-2024.5.10-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:29d839829209f3c53f004e1de8c3113efce6d98029f044fa5cfee666253ee7e6"},
+ {file = "regex-2024.5.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0709ba544cf50bd5cb843df4b8bb6701bae2b70a8e88da9add8386cbca5c1385"},
+ {file = "regex-2024.5.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:972b49f2fe1047b9249c958ec4fa1bdd2cf8ce305dc19d27546d5a38e57732d8"},
+ {file = "regex-2024.5.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9cdbb1998da94607d5eec02566b9586f0e70d6438abf1b690261aac0edda7ab6"},
+ {file = "regex-2024.5.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7c8ee4861d9ef5b1120abb75846828c811f932d63311596ad25fa168053e00"},
+ {file = "regex-2024.5.10-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d35d4cc9270944e95f9c88af757b0c9fc43f396917e143a5756608462c5223b"},
+ {file = "regex-2024.5.10-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8722f72068b3e1156a4b2e1afde6810f1fc67155a9fa30a4b9d5b4bc46f18fb0"},
+ {file = "regex-2024.5.10-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:696639a73ca78a380acfaa0a1f6dd8220616a99074c05bba9ba8bb916914b224"},
+ {file = "regex-2024.5.10-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ea057306ab469130167014b662643cfaed84651c792948891d003cf0039223a5"},
+ {file = "regex-2024.5.10-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:b43b78f9386d3d932a6ce5af4b45f393d2e93693ee18dc4800d30a8909df700e"},
+ {file = "regex-2024.5.10-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c43395a3b7cc9862801a65c6994678484f186ce13c929abab44fb8a9e473a55a"},
+ {file = "regex-2024.5.10-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0bc94873ba11e34837bffd7e5006703abeffc4514e2f482022f46ce05bd25e67"},
+ {file = "regex-2024.5.10-cp38-cp38-win32.whl", hash = "sha256:1118ba9def608250250f4b3e3f48c62f4562ba16ca58ede491b6e7554bfa09ff"},
+ {file = "regex-2024.5.10-cp38-cp38-win_amd64.whl", hash = "sha256:458d68d34fb74b906709735c927c029e62f7d06437a98af1b5b6258025223210"},
+ {file = "regex-2024.5.10-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:15e593386ec6331e0ab4ac0795b7593f02ab2f4b30a698beb89fbdc34f92386a"},
+ {file = "regex-2024.5.10-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ca23b41355ba95929e9505ee04e55495726aa2282003ed9b012d86f857d3e49b"},
+ {file = "regex-2024.5.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2c8982ee19ccecabbaeac1ba687bfef085a6352a8c64f821ce2f43e6d76a9298"},
+ {file = "regex-2024.5.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7117cb7d6ac7f2e985f3d18aa8a1728864097da1a677ffa69e970ca215baebf1"},
+ {file = "regex-2024.5.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b66421f8878a0c82fc0c272a43e2121c8d4c67cb37429b764f0d5ad70b82993b"},
+ {file = "regex-2024.5.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:224a9269f133564109ce668213ef3cb32bc72ccf040b0b51c72a50e569e9dc9e"},
+ {file = "regex-2024.5.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab98016541543692a37905871a5ffca59b16e08aacc3d7d10a27297b443f572d"},
+ {file = "regex-2024.5.10-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:51d27844763c273a122e08a3e86e7aefa54ee09fb672d96a645ece0454d8425e"},
+ {file = "regex-2024.5.10-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:853cc36e756ff673bf984e9044ccc8fad60b95a748915dddeab9488aea974c73"},
+ {file = "regex-2024.5.10-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4e7eaf9df15423d07b6050fb91f86c66307171b95ea53e2d87a7993b6d02c7f7"},
+ {file = "regex-2024.5.10-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:169fd0acd7a259f58f417e492e93d0e15fc87592cd1e971c8c533ad5703b5830"},
+ {file = "regex-2024.5.10-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:334b79ce9c08f26b4659a53f42892793948a613c46f1b583e985fd5a6bf1c149"},
+ {file = "regex-2024.5.10-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:f03b1dbd4d9596dd84955bb40f7d885204d6aac0d56a919bb1e0ff2fb7e1735a"},
+ {file = "regex-2024.5.10-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfa6d61a76c77610ba9274c1a90a453062bdf6887858afbe214d18ad41cf6bde"},
+ {file = "regex-2024.5.10-cp39-cp39-win32.whl", hash = "sha256:249fbcee0a277c32a3ce36d8e36d50c27c968fdf969e0fbe342658d4e010fbc8"},
+ {file = "regex-2024.5.10-cp39-cp39-win_amd64.whl", hash = "sha256:0ce56a923f4c01d7568811bfdffe156268c0a7aae8a94c902b92fe34c4bde785"},
+ {file = "regex-2024.5.10.tar.gz", hash = "sha256:304e7e2418146ae4d0ef0e9ffa28f881f7874b45b4994cc2279b21b6e7ae50c8"},
]
[[package]]
@@ -3233,30 +3141,6 @@ typing-extensions = ">=4.5"
[package.extras]
tests = ["pytest (==7.1.2)"]
-[[package]]
-name = "robomimic"
-version = "0.2.0"
-description = "robomimic: A Modular Framework for Robot Learning from Demonstration"
-optional = false
-python-versions = ">=3"
-files = [
- {file = "robomimic-0.2.0.tar.gz", hash = "sha256:ee3bb5cf9c3e1feead6b57b43c5db738fd0a8e0c015fdf6419808af8fffdc463"},
-]
-
-[package.dependencies]
-egl_probe = ">=1.0.1"
-h5py = "*"
-imageio = "*"
-imageio-ffmpeg = "*"
-numpy = ">=1.13.3"
-psutil = "*"
-tensorboard = "*"
-tensorboardX = "*"
-termcolor = "*"
-torch = "*"
-torchvision = "*"
-tqdm = "*"
-
[[package]]
name = "safetensors"
version = "0.4.3"
@@ -3381,51 +3265,46 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]
[[package]]
name = "scikit-image"
-version = "0.22.0"
+version = "0.23.2"
description = "Image processing in Python"
optional = true
-python-versions = ">=3.9"
+python-versions = ">=3.10"
files = [
- {file = "scikit_image-0.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74ec5c1d4693506842cc7c9487c89d8fc32aed064e9363def7af08b8f8cbb31d"},
- {file = "scikit_image-0.22.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:a05ae4fe03d802587ed8974e900b943275548cde6a6807b785039d63e9a7a5ff"},
- {file = "scikit_image-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a92dca3d95b1301442af055e196a54b5a5128c6768b79fc0a4098f1d662dee6"},
- {file = "scikit_image-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3663d063d8bf2fb9bdfb0ca967b9ee3b6593139c860c7abc2d2351a8a8863938"},
- {file = "scikit_image-0.22.0-cp310-cp310-win_amd64.whl", hash = "sha256:ebdbdc901bae14dab637f8d5c99f6d5cc7aaf4a3b6f4003194e003e9f688a6fc"},
- {file = "scikit_image-0.22.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:95d6da2d8a44a36ae04437c76d32deb4e3c993ffc846b394b9949fd8ded73cb2"},
- {file = "scikit_image-0.22.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:2c6ef454a85f569659b813ac2a93948022b0298516b757c9c6c904132be327e2"},
- {file = "scikit_image-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87872f067444ee90a00dd49ca897208308645382e8a24bd3e76f301af2352cd"},
- {file = "scikit_image-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5c378db54e61b491b9edeefff87e49fcf7fdf729bb93c777d7a5f15d36f743e"},
- {file = "scikit_image-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:2bcb74adb0634258a67f66c2bb29978c9a3e222463e003b67ba12056c003971b"},
- {file = "scikit_image-0.22.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:003ca2274ac0fac252280e7179ff986ff783407001459ddea443fe7916e38cff"},
- {file = "scikit_image-0.22.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:cf3c0c15b60ae3e557a0c7575fbd352f0c3ce0afca562febfe3ab80efbeec0e9"},
- {file = "scikit_image-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5b23908dd4d120e6aecb1ed0277563e8cbc8d6c0565bdc4c4c6475d53608452"},
- {file = "scikit_image-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be79d7493f320a964f8fcf603121595ba82f84720de999db0fcca002266a549a"},
- {file = "scikit_image-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:722b970aa5da725dca55252c373b18bbea7858c1cdb406e19f9b01a4a73b30b2"},
- {file = "scikit_image-0.22.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:22318b35044cfeeb63ee60c56fc62450e5fe516228138f1d06c7a26378248a86"},
- {file = "scikit_image-0.22.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:9e801c44a814afdadeabf4dffdffc23733e393767958b82319706f5fa3e1eaa9"},
- {file = "scikit_image-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c472a1fb3665ec5c00423684590631d95f9afcbc97f01407d348b821880b2cb3"},
- {file = "scikit_image-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b7a6c89e8d6252332121b58f50e1625c35f7d6a85489c0b6b7ee4f5155d547a"},
- {file = "scikit_image-0.22.0-cp39-cp39-win_amd64.whl", hash = "sha256:5071b8f6341bfb0737ab05c8ab4ac0261f9e25dbcc7b5d31e5ed230fd24a7929"},
- {file = "scikit_image-0.22.0.tar.gz", hash = "sha256:018d734df1d2da2719087d15f679d19285fce97cd37695103deadfaef2873236"},
+ {file = "scikit_image-0.23.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f9a8db6c52f8d0e1474ea8320d7b8db442b4d6baa29dd0acbd02f8a49572f18a"},
+ {file = "scikit_image-0.23.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:524b51a7440e46ed2ebbde7bc288bf2dde1dee2caafdd9513b2aca38a48223b7"},
+ {file = "scikit_image-0.23.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b335c229170d787b3fb8c60d220f72049ccf862d5191a3cfda6ac84b995ac4e"},
+ {file = "scikit_image-0.23.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08b10781efbd6b084f3c847ff4049b657241ea866b9e331b14bf791dcb3e6661"},
+ {file = "scikit_image-0.23.2-cp310-cp310-win_amd64.whl", hash = "sha256:a207352e9a1956dda1424bbe872c7795345187138118e8be6a421aef3b988c2a"},
+ {file = "scikit_image-0.23.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee83fdb1843ee938eabdfeb9498623282935ea30aa20dffc5d5d16698efb4b2a"},
+ {file = "scikit_image-0.23.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:a158f50d3df4867bbd1c698520ede8bc493e430ad83f54ac1f0d8f57b328779b"},
+ {file = "scikit_image-0.23.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55de3326be124334b89314e9e04c8971ad98d6681e11a243f71bfb85ef9554b0"},
+ {file = "scikit_image-0.23.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fce619a6d84fe40c1208fa579b646e93ce13ef0afc3652a23e9782b2c183291a"},
+ {file = "scikit_image-0.23.2-cp311-cp311-win_amd64.whl", hash = "sha256:ee65669aa586e110346f567ed5c92d1bd63799a19e951cb83da3f54b0caf7c52"},
+ {file = "scikit_image-0.23.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:15bfb4e8d7bd90a967e6a3c3ab6be678063fc45e950b730684a8db46a02ff892"},
+ {file = "scikit_image-0.23.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5736e66d01b11cd90988ec24ab929c80a03af28f690189c951886891ebf63154"},
+ {file = "scikit_image-0.23.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3597ac5d8f51dafbcb7433ef1fdefdefb535f50745b2002ae0a5d651df4f063b"},
+ {file = "scikit_image-0.23.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1978be2abe3c3c3189a99a411d48bbb1306f7c2debb3aefbf426e23947f26623"},
+ {file = "scikit_image-0.23.2-cp312-cp312-win_amd64.whl", hash = "sha256:ae32bf0cb02b672ed74d28880ca6f88928ae8dd794d67e04fa3ff4836feb9bd6"},
+ {file = "scikit_image-0.23.2.tar.gz", hash = "sha256:c9da4b2c3117e3e30364a3d14496ee5c72b09eb1a4ab1292b302416faa360590"},
]
[package.dependencies]
-imageio = ">=2.27"
-lazy_loader = ">=0.3"
+imageio = ">=2.33"
+lazy-loader = ">=0.4"
networkx = ">=2.8"
-numpy = ">=1.22"
+numpy = ">=1.23"
packaging = ">=21"
-pillow = ">=9.0.1"
-scipy = ">=1.8"
+pillow = ">=9.1"
+scipy = ">=1.9"
tifffile = ">=2022.8.12"
[package.extras]
-build = ["Cython (>=0.29.32)", "build", "meson-python (>=0.14)", "ninja", "numpy (>=1.22)", "packaging (>=21)", "pythran", "setuptools (>=67)", "spin (==0.6)", "wheel"]
+build = ["Cython (>=3.0.4)", "build", "meson-python (>=0.15)", "ninja", "numpy (>=2.0.0rc1)", "packaging (>=21)", "pythran", "setuptools (>=67)", "spin (==0.8)", "wheel"]
data = ["pooch (>=1.6.0)"]
-developer = ["pre-commit", "tomli"]
-docs = ["PyWavelets (>=1.1.1)", "dask[array] (>=2022.9.2)", "ipykernel", "ipywidgets", "kaleido", "matplotlib (>=3.5)", "myst-parser", "numpydoc (>=1.6)", "pandas (>=1.5)", "plotly (>=5.10)", "pooch (>=1.6)", "pydata-sphinx-theme (>=0.14.1)", "pytest-runner", "scikit-learn (>=1.1)", "seaborn (>=0.11)", "sphinx (>=7.2)", "sphinx-copybutton", "sphinx-gallery (>=0.14)", "sphinx_design (>=0.5)", "tifffile (>=2022.8.12)"]
-optional = ["PyWavelets (>=1.1.1)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=0.2.1)", "dask[array] (>=2021.1.0)", "matplotlib (>=3.5)", "pooch (>=1.6.0)", "pyamg", "scikit-learn (>=1.1)"]
-test = ["asv", "matplotlib (>=3.5)", "numpydoc (>=1.5)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-cov (>=2.11.0)", "pytest-faulthandler", "pytest-localserver"]
+developer = ["ipython", "pre-commit", "tomli"]
+docs = ["PyWavelets (>=1.1.1)", "dask[array] (>=2022.9.2)", "ipykernel", "ipywidgets", "kaleido", "matplotlib (>=3.6)", "myst-parser", "numpydoc (>=1.7)", "pandas (>=1.5)", "plotly (>=5.10)", "pooch (>=1.6)", "pydata-sphinx-theme (>=0.15.2)", "pytest-doctestplus", "pytest-runner", "scikit-learn (>=1.1)", "seaborn (>=0.11)", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-gallery (>=0.14)", "sphinx_design (>=0.5)", "tifffile (>=2022.8.12)"]
+optional = ["PyWavelets (>=1.1.1)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=0.2.1)", "dask[array] (>=2021.1.0)", "matplotlib (>=3.6)", "pooch (>=1.6.0)", "pyamg", "scikit-learn (>=1.1)"]
+test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-cov (>=2.11.0)", "pytest-doctestplus", "pytest-faulthandler", "pytest-localserver"]
[[package]]
name = "scipy"
@@ -3471,13 +3350,13 @@ test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "po
[[package]]
name = "sentry-sdk"
-version = "2.0.1"
+version = "2.1.1"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = ">=3.6"
files = [
- {file = "sentry_sdk-2.0.1-py2.py3-none-any.whl", hash = "sha256:b54c54a2160f509cf2757260d0cf3885b608c6192c2555a3857e3a4d0f84bdb3"},
- {file = "sentry_sdk-2.0.1.tar.gz", hash = "sha256:c278e0f523f6f0ee69dc43ad26dcdb1202dffe5ac326ae31472e012d941bee21"},
+ {file = "sentry_sdk-2.1.1-py2.py3-none-any.whl", hash = "sha256:99aeb78fb76771513bd3b2829d12613130152620768d00cd3e45ac00cb17950f"},
+ {file = "sentry_sdk-2.1.1.tar.gz", hash = "sha256:95d8c0bb41c8b0bc37ab202c2c4a295bb84398ee05f4cdce55051cd75b926ec1"},
]
[package.dependencies]
@@ -3486,6 +3365,7 @@ urllib3 = ">=1.26.11"
[package.extras]
aiohttp = ["aiohttp (>=3.5)"]
+anthropic = ["anthropic (>=0.16)"]
arq = ["arq (>=0.23)"]
asyncpg = ["asyncpg (>=0.23)"]
beam = ["apache-beam (>=2.12)"]
@@ -3501,6 +3381,8 @@ flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
grpcio = ["grpcio (>=1.21.1)"]
httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
+huggingface-hub = ["huggingface-hub (>=0.22)"]
+langchain = ["langchain (>=0.0.210)"]
loguru = ["loguru (>=0.5)"]
openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
@@ -3749,55 +3631,6 @@ files = [
{file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
]
-[[package]]
-name = "tensorboard"
-version = "2.16.2"
-description = "TensorBoard lets you watch Tensors Flow"
-optional = false
-python-versions = ">=3.9"
-files = [
- {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"},
-]
-
-[package.dependencies]
-absl-py = ">=0.4"
-grpcio = ">=1.48.2"
-markdown = ">=2.6.8"
-numpy = ">=1.12.0"
-protobuf = ">=3.19.6,<4.24.0 || >4.24.0"
-setuptools = ">=41.0.0"
-six = ">1.9"
-tensorboard-data-server = ">=0.7.0,<0.8.0"
-werkzeug = ">=1.0.1"
-
-[[package]]
-name = "tensorboard-data-server"
-version = "0.7.2"
-description = "Fast data loading for TensorBoard"
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
- {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
- {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
-]
-
-[[package]]
-name = "tensorboardx"
-version = "2.6.2.2"
-description = "TensorBoardX lets you watch Tensors Flow without Tensorflow"
-optional = false
-python-versions = "*"
-files = [
- {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"},
- {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"},
-]
-
-[package.dependencies]
-numpy = "*"
-packaging = "*"
-protobuf = ">=3.20"
-
[[package]]
name = "termcolor"
version = "2.4.0"
@@ -3814,13 +3647,13 @@ tests = ["pytest", "pytest-cov"]
[[package]]
name = "tifffile"
-version = "2024.4.24"
+version = "2024.5.10"
description = "Read and write TIFF files"
optional = true
python-versions = ">=3.9"
files = [
- {file = "tifffile-2024.4.24-py3-none-any.whl", hash = "sha256:8d0b982f4b01ace358835ae6c2beb5a70cb7287f5d3a2e96c318bd5befa97b1f"},
- {file = "tifffile-2024.4.24.tar.gz", hash = "sha256:e329f36ac8ff3bbe7dd04609340be26b03c4b9e9a69235fc3ab33434157c38ea"},
+ {file = "tifffile-2024.5.10-py3-none-any.whl", hash = "sha256:4154f091aa24d4e75bfad9ab2d5424a68c70e67b8220188066dc61946d4551bd"},
+ {file = "tifffile-2024.5.10.tar.gz", hash = "sha256:aa1e1b12be952ab20717d6848bd6d4a5ee88d2aa319f1152bff4354ad728ec86"},
]
[package.dependencies]
@@ -3933,13 +3766,13 @@ scipy = ["scipy"]
[[package]]
name = "tqdm"
-version = "4.66.2"
+version = "4.66.4"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
- {file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"},
- {file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"},
+ {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"},
+ {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"},
]
[package.dependencies]
@@ -4035,59 +3868,46 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
[[package]]
name = "wandb"
-version = "0.16.6"
+version = "0.17.0"
description = "A CLI and library for interacting with the Weights & Biases API."
optional = false
python-versions = ">=3.7"
files = [
- {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"},
- {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"},
+ {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"},
+ {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"},
+ {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"},
+ {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"},
+ {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"},
+ {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"},
+ {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"},
]
[package.dependencies]
-appdirs = ">=1.4.3"
-Click = ">=7.1,<8.0.0 || >8.0.0"
+click = ">=7.1,<8.0.0 || >8.0.0"
docker-pycreds = ">=0.4.0"
-GitPython = ">=1.0.0,<3.1.29 || >3.1.29"
+gitpython = ">=1.0.0,<3.1.29 || >3.1.29"
+platformdirs = "*"
protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""}
psutil = ">=5.0.0"
-PyYAML = "*"
+pyyaml = "*"
requests = ">=2.0.0,<3"
sentry-sdk = ">=1.0.0"
setproctitle = "*"
setuptools = "*"
[package.extras]
-async = ["httpx (>=0.23.0)"]
aws = ["boto3"]
azure = ["azure-identity", "azure-storage-blob"]
gcp = ["google-cloud-storage"]
importers = ["filelock", "mlflow", "polars", "rich", "tenacity"]
kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
-launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"]
+launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "typing-extensions"]
media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"]
models = ["cloudpickle"]
perf = ["orjson"]
reports = ["pydantic (>=2.0.0)"]
sweeps = ["sweeps (>=0.2.0)"]
-[[package]]
-name = "werkzeug"
-version = "3.0.2"
-description = "The comprehensive WSGI web application library."
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "werkzeug-3.0.2-py3-none-any.whl", hash = "sha256:3aac3f5da756f93030740bc235d3e09449efcf65f2f55e3602e1d851b8f48795"},
- {file = "werkzeug-3.0.2.tar.gz", hash = "sha256:e39b645a6ac92822588e7b39a692e7828724ceae0b0d702ef96701f90e70128d"},
-]
-
-[package.dependencies]
-MarkupSafe = ">=2.1.1"
-
-[package.extras]
-watchdog = ["watchdog (>=2.3)"]
-
[[package]]
name = "xxhash"
version = "3.4.1"
@@ -4310,13 +4130,13 @@ multidict = ">=4.0"
[[package]]
name = "zarr"
-version = "2.17.2"
+version = "2.18.0"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false
python-versions = ">=3.9"
files = [
- {file = "zarr-2.17.2-py3-none-any.whl", hash = "sha256:70d7cc07c24280c380ef80644151d136b7503b0d83c9f214e8000ddc0f57f69b"},
- {file = "zarr-2.17.2.tar.gz", hash = "sha256:2cbaa6cb4e342d45152d4a7a4b2013c337fcd3a8e7bc98253560180de60552ce"},
+ {file = "zarr-2.18.0-py3-none-any.whl", hash = "sha256:7f8532b6a3f50f22e809e130e09353637ec8b5bb5e95a5a0bfaae91f63978b5d"},
+ {file = "zarr-2.18.0.tar.gz", hash = "sha256:c3b7d2c85b8a42b0ad0ad268a36fb6886ca852098358c125c6b126a417e0a598"},
]
[package.dependencies]
@@ -4355,4 +4175,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
-content-hash = "d2066576dc4aebaf623c295fe626bf6805fd2ec26a6ba47fa5415204994aa922"
+content-hash = "e3e3c306a5519e4f716a1ac086ad9b734efedcac077a0ec71e5bc16349a1e559"
diff --git a/pyproject.toml b/pyproject.toml
index 324288452..5b80d06f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,9 +4,9 @@ version = "0.1.0"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
authors = [
"Rémi Cadène ",
+ "Simon Alibert ",
"Alexander Soare ",
"Quentin GallouΓ©dec ",
- "Simon Alibert ",
"Adil Zouitine ",
"Thomas Wolf ",
]
@@ -28,37 +28,36 @@ packages = [{include = "lerobot"}]
[tool.poetry.dependencies]
python = ">=3.10,<3.13"
-termcolor = "^2.4.0"
-omegaconf = "^2.3.0"
-wandb = "^0.16.3"
-imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
-gdown = "^5.1.0"
-hydra-core = "^1.3.2"
-einops = "^0.8.0"
-pymunk = "^6.6.0"
-zarr = "^2.17.0"
-numba = "^0.59.0"
+termcolor = ">=2.4.0"
+omegaconf = ">=2.3.0"
+wandb = ">=0.16.3"
+imageio = {extras = ["ffmpeg"], version = ">=2.34.0"}
+gdown = ">=5.1.0"
+hydra-core = ">=1.3.2"
+einops = ">=0.8.0"
+pymunk = ">=6.6.0"
+zarr = ">=2.17.0"
+numba = ">=0.59.0"
torch = "^2.2.1"
-opencv-python = "^4.9.0.80"
+opencv-python = ">=4.9.0"
diffusers = "^0.27.2"
-torchvision = "^0.18.0"
-h5py = "^3.10.0"
-huggingface-hub = "^0.21.4"
-robomimic = "0.2.0"
-gymnasium = "^0.29.1"
-cmake = "^3.29.0.1"
-gym-pusht = { version = "^0.1.0", optional = true}
-gym-xarm = { version = "^0.1.0", optional = true}
-gym-aloha = { version = "^0.1.0", optional = true}
-pre-commit = {version = "^3.7.0", optional = true}
-debugpy = {version = "^1.8.1", optional = true}
-pytest = {version = "^8.1.0", optional = true}
-pytest-cov = {version = "^5.0.0", optional = true}
+torchvision = ">=0.18.0"
+h5py = ">=3.10.0"
+huggingface-hub = ">=0.21.4"
+gymnasium = ">=0.29.1"
+cmake = ">=3.29.0.1"
+gym-pusht = { version = ">=0.1.3", optional = true}
+gym-xarm = { version = ">=0.1.1", optional = true}
+gym-aloha = { version = ">=0.1.1", optional = true}
+pre-commit = {version = ">=3.7.0", optional = true}
+debugpy = {version = ">=1.8.1", optional = true}
+pytest = {version = ">=8.1.0", optional = true}
+pytest-cov = {version = ">=5.0.0", optional = true}
datasets = "^2.19.0"
-imagecodecs = { version = "^2024.1.1", optional = true }
-pyav = "^12.0.5"
-moviepy = "^1.0.3"
-rerun-sdk = "^0.15.1"
+imagecodecs = { version = ">=2024.1.1", optional = true }
+pyav = ">=12.0.5"
+moviepy = ">=1.0.3"
+rerun-sdk = ">=0.15.1"
[tool.poetry.extras]
@@ -104,5 +103,5 @@ ignore-init-module-imports = true
[build-system]
-requires = ["poetry-core>=1.5.0"]
+requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
diff --git a/tests/conftest.py b/tests/conftest.py
index 856ca4555..62f831aa3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from .utils import DEVICE
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors
index 70c9b6d81..3c9447d7f 100644
Binary files a/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/actions.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors
index 2e8451891..7dfbc3b35 100644
Binary files a/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/grad_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors
index e8d537c82..4c738f397 100644
Binary files a/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/output_dict.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors
index 6e33879f3..7a2e0e70e 100644
Binary files a/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors and b/tests/data/save_policy_to_safetensors/aloha_act/param_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors
index 730f5b2bc..8f0399035 100644
Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/actions.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors
index 1b1142b2d..2b6593960 100644
Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/grad_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors
index 77472bb50..a9f61b36e 100644
Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/output_dict.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors
index ade6a9e03..a9f4608f6 100644
Binary files a/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors and b/tests/data/save_policy_to_safetensors/pusht_diffusion/param_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors
new file mode 100644
index 000000000..0339ca0e3
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/actions.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors
new file mode 100644
index 000000000..5520c643c
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/grad_stats.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors
new file mode 100644
index 000000000..2321f31c7
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/output_dict.safetensors differ
diff --git a/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors
new file mode 100644
index 000000000..5e8a69470
Binary files /dev/null and b/tests/data/save_policy_to_safetensors/xarm_tdmpc/param_stats.safetensors differ
diff --git a/tests/scripts/save_dataset_to_safetensors.py b/tests/scripts/save_dataset_to_safetensors.py
index 17cf2b38f..554efe758 100644
--- a/tests/scripts/save_dataset_to_safetensors.py
+++ b/tests/scripts/save_dataset_to_safetensors.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
This script provides a utility for saving a dataset as safetensors files for the purpose of testing backward compatibility
when updating the data format. It uses the `PushtDataset` to create a DataLoader and saves selected frame from the
diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py
index 29e9a34f9..e79a94ff9 100644
--- a/tests/scripts/save_policy_to_safetensor.py
+++ b/tests/scripts/save_policy_to_safetensor.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import shutil
from pathlib import Path
diff --git a/tests/test_available.py b/tests/test_available.py
index e2dfeeb95..db5bd520a 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import importlib
import gymnasium as gym
@@ -15,7 +30,7 @@ from tests.utils import require_env
def test_available_env_task(env_name: str, task_name: list):
"""
This test verifies that all environments listed in `lerobot/__init__.py` can
- be sucessfully imported – if they're installed – and that their
+ be successfully imported – if they're installed – and that their
`available_tasks_per_env` are valid.
"""
package_name = f"gym_{env_name}"
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 23996abac..664545050 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import logging
from copy import deepcopy
@@ -41,7 +56,7 @@ def test_factory(env_name, repo_id, policy_name):
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
- image_keys = dataset.image_keys
+ camera_keys = dataset.camera_keys
item = dataset[0]
@@ -71,7 +86,7 @@ def test_factory(env_name, repo_id, policy_name):
else:
assert item[key].ndim == ndim, f"{key}"
- if key in image_keys:
+ if key in camera_keys:
assert item[key].dtype == torch.float32, f"{key}"
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert item[key].max() <= 1.0, f"{key}"
diff --git a/tests/test_envs.py b/tests/test_envs.py
index f172a6458..aec9999da 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import importlib
import gymnasium as gym
diff --git a/tests/test_examples.py b/tests/test_examples.py
index ccce4eb24..de95a9915 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,8 +1,25 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import subprocess
import sys
from pathlib import Path
+from tests.utils import require_package
+
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
for f, r in finds_and_replaces:
@@ -21,6 +38,7 @@ def test_example_1():
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
+@require_package("gym_pusht")
def test_examples_3_and_2():
"""
Train a model with example 3, check the outputs.
@@ -46,7 +64,7 @@ def test_examples_3_and_2():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
- for file_name in ["model.safetensors", "config.json", "config.yaml"]:
+ for file_name in ["model.safetensors", "config.json"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
@@ -58,16 +76,16 @@ def test_examples_3_and_2():
file_contents = _find_and_replace(
file_contents,
[
- ('pretrained_policy_name = "lerobot/diffusion_pusht"', ""),
- ("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
+ ('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
- ('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
- ('"eval.batch_size=10"', '"eval.batch_size=1"'),
- ('"device=cuda"', '"device=cpu"'),
+ ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
+ ("step += 1", "break"),
],
)
- assert Path("outputs/train/example_pusht_diffusion").exists()
+ exec(file_contents, {})
+
+ assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 12beec92c..75633fe66 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import inspect
from pathlib import Path
@@ -49,6 +64,14 @@ def test_get_policy_and_config_classes(policy_name: str):
"act",
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
),
+ # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+ (
+ "aloha",
+ "diffusion",
+ ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
+ ),
+ # Note: these parameters also need custom logic in the test function for overriding the Hydra config.
+ ("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
],
)
@require_env
@@ -72,6 +95,31 @@ def test_policy(env_name, policy_name, extra_overrides):
+ extra_overrides,
)
+ # Additional config override logic.
+ if env_name == "aloha" and policy_name == "diffusion":
+ for keys in [
+ ("training", "delta_timestamps"),
+ ("policy", "input_shapes"),
+ ("policy", "input_normalization_modes"),
+ ]:
+ dct = dict(cfg[keys[0]][keys[1]])
+ dct["observation.images.top"] = dct["observation.image"]
+ del dct["observation.image"]
+ cfg[keys[0]][keys[1]] = dct
+ cfg.override_dataset_stats = None
+
+ # Additional config override logic.
+ if env_name == "pusht" and policy_name == "act":
+ for keys in [
+ ("policy", "input_shapes"),
+ ("policy", "input_normalization_modes"),
+ ]:
+ dct = dict(cfg[keys[0]][keys[1]])
+ dct["observation.image"] = dct["observation.images.top"]
+ del dct["observation.images.top"]
+ cfg[keys[0]][keys[1]] = dct
+ cfg.override_dataset_stats = None
+
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
@@ -236,7 +284,7 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides",
[
- # ("xarm", "tdmpc", ["policy.n_action_repeats=2"]),
+ ("xarm", "tdmpc", []),
(
"pusht",
"diffusion",
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 000000000..bcdd95b4e
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,38 @@
+import random
+from typing import Callable
+
+import numpy as np
+import pytest
+import torch
+
+from lerobot.common.utils.utils import seeded_context, set_global_seed
+
+
+@pytest.mark.parametrize(
+ "rand_fn",
+ [
+ random.random,
+ np.random.random,
+ lambda: torch.rand(1).item(),
+ ]
+    + (
+        [lambda: torch.rand(1, device="cuda")] if torch.cuda.is_available() else []
+    ),
+)
+def test_seeding(rand_fn: Callable[[], int]):
+ set_global_seed(0)
+ a = rand_fn()
+ with seeded_context(1337):
+ c = rand_fn()
+ b = rand_fn()
+ set_global_seed(0)
+ a_ = rand_fn()
+ b_ = rand_fn()
+ # Check that `set_global_seed` lets us reproduce a and b.
+ assert a_ == a
+ # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
+ assert b_ == b
+ set_global_seed(1337)
+ c_ = rand_fn()
+    # Check that `seeded_context` and `set_global_seed` give the same reproducibility.
+ assert c_ == c
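The new `tests/test_utils.py` above exercises `set_global_seed` and `seeded_context` from `lerobot.common.utils.utils`. As a minimal sketch of the behaviour the test asserts (the actual implementation in `lerobot/common/utils/utils.py` may differ), a seeded context saves the Python, NumPy and torch RNG states, reseeds them, and restores the saved states on exit so the surrounding global stream is not interrupted:

```python
# Sketch only: the real helpers live in lerobot/common/utils/utils.py.
import random
from contextlib import contextmanager

import numpy as np
import torch


def set_global_seed(seed: int) -> None:
    """Seed every RNG the tests may draw from."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


@contextmanager
def seeded_context(seed: int):
    """Run a block under a fixed seed, then restore the previous RNG states."""
    py_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.random.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    set_global_seed(seed)
    try:
        yield
    finally:
        random.setstate(py_state)
        np.random.set_state(np_state)
        torch.random.set_rng_state(torch_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)
```

Restoring the saved states on exit is what makes `b_ == b` hold in the test: the global stream continues exactly where it left off before entering the context.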
diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py
index 0124afd3f..999540402 100644
--- a/tests/test_visualize_dataset.py
+++ b/tests/test_visualize_dataset.py
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset
diff --git a/tests/utils.py b/tests/utils.py
index 6a7066942..ba49ee706 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,4 +1,20 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import platform
+from functools import wraps
import pytest
import torch
@@ -61,7 +77,6 @@ def require_env(func):
Decorator that skips the test if the required environment package is not installed.
As it need 'env_name' in args, it also checks whether it is provided as an argument.
"""
- from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
@@ -82,3 +97,20 @@ def require_env(func):
return func(*args, **kwargs)
return wrapper
+
+
+def require_package(package_name):
+ """
+ Decorator that skips the test if the specified package is not installed.
+ """
+
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not is_package_available(package_name):
+ pytest.skip(f"{package_name} not installed")
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
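
For reference, a hypothetical call site for the new `require_package` decorator (the real usage added in this PR is `@require_package("gym_pusht")` on `test_examples_3_and_2`; the test name below is made up for illustration):

```python
from tests.utils import require_package


@require_package("gym_pusht")
def test_something_with_pusht():
    # Skipped automatically when gym_pusht is not installed, e.g. in the minimal-install CI job.
    import gym_pusht  # noqa: F401
```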