Compare commits

1 commit

Author: mshukor
SHA1: 86e8c1d997
Message: supporting accelerate
Date: 2025-02-25 21:38:17 +01:00

102 changed files with 395 additions and 547 deletions


@@ -1,125 +0,0 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml
name: PR Style Bot
on:
issue_comment:
types: [created]
permissions:
contents: write
pull-requests: write
jobs:
run-style-bot:
if: >
contains(github.event.comment.body, '@bot /style') &&
github.event.issue.pull_request != null
runs-on: ubuntu-latest
steps:
- name: Extract PR details
id: pr_info
uses: actions/github-script@v6
with:
script: |
const prNumber = context.payload.issue.number;
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber
});
// We capture both the branch ref and the "full_name" of the head repo
// so that we can check out the correct repository & branch (including forks).
core.setOutput("prNumber", prNumber);
core.setOutput("headRef", pr.head.ref);
core.setOutput("headRepoFullName", pr.head.repo.full_name);
- name: Check out PR branch
uses: actions/checkout@v4
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
with:
persist-credentials: true
# Instead of checking out the base repo, use the contributor's repo name
repository: ${{ env.HEADREPOFULLNAME }}
ref: ${{ env.HEADREF }}
# You may need fetch-depth: 0 for being able to push
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Debug
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
run: |
echo "PR number: ${PRNUMBER}"
echo "Head Ref: ${HEADREF}"
echo "Head Repo Full Name: ${HEADREPOFULLNAME}"
- name: Set up Python
uses: actions/setup-python@v4
- name: Get Ruff Version from pre-commit-config.yaml
id: get-ruff-version
run: |
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
- name: Install Ruff
env:
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
run: python -m pip install "ruff==${RUFF_VERSION}"
- name: Ruff check
run: ruff check --fix
- name: Ruff format
run: ruff format
- name: Commit and push changes
id: commit_and_push
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
echo "HEADREPOFULLNAME: ${HEADREPOFULLNAME}, HEADREF: ${HEADREF}"
# Configure git with the Actions bot user
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Make sure your 'origin' remote is set to the contributor's fork
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"
# If there are changes after running style/quality, commit them
if [ -n "$(git status --porcelain)" ]; then
git add .
git commit -m "Apply style fixes"
# Push to the original contributor's forked branch
git push origin HEAD:${HEADREF}
echo "changes_pushed=true" >> $GITHUB_OUTPUT
else
echo "No changes to commit."
echo "changes_pushed=false" >> $GITHUB_OUTPUT
fi
- name: Comment on PR with workflow run link
if: steps.commit_and_push.outputs.changes_pushed == 'true'
uses: actions/github-script@v6
with:
script: |
const prNumber = parseInt(process.env.prNumber, 10);
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
});
env:
prNumber: ${{ steps.pr_info.outputs.prNumber }}


@@ -32,27 +32,13 @@ jobs:
id: get-ruff-version
run: |
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
- name: Install Ruff
env:
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
run: python -m pip install "ruff==${RUFF_VERSION}"
run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
- name: Ruff check
run: ruff check --output-format=github
- name: Ruff format
run: ruff format --diff
typos:
name: Typos
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
with:
persist-credentials: false
- name: typos-action
uses: crate-ci/typos@v1.29.10


@@ -43,7 +43,7 @@ jobs:
needs: get_changed_files
runs-on:
group: aws-general-8-plus
if: needs.get_changed_files.outputs.matrix != ''
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
strategy:
fail-fast: false
matrix:


@@ -13,11 +13,6 @@ repos:
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/crate-ci/typos
rev: v1.29.10
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:


@@ -228,7 +228,7 @@ Follow these steps to start contributing:
git commit
```
Note, if you already committed some changes that have a wrong formatting, you can use:
Note, if you already commited some changes that have a wrong formatting, you can use:
```bash
pre-commit run --all-files
```


@@ -114,7 +114,7 @@ We tried to measure the most impactful parameters for both encoding and decoding
Additional encoding parameters exist that are not included in this benchmark. In particular:
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
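For instance, a minimal sketch of passing these flags explicitly (file names and values are illustrative, not the benchmark's actual settings):

```python
import subprocess

# Encode a hypothetical frames directory with libx264, making the implicit
# defaults above explicit: -preset medium plus a -tune choice.
subprocess.run([
    "ffmpeg",
    "-framerate", "30",
    "-i", "frames/frame_%06d.png",  # hypothetical input pattern
    "-c:v", "libx264",
    "-preset", "medium",            # speed vs. compression trade-off
    "-tune", "fastdecode",          # optimize for fast decoding
    "-pix_fmt", "yuv420p",
    "output.mp4",
], check=True)
```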


@@ -1,29 +1,33 @@
# Configure image
ARG PYTHON_VERSION=3.10
FROM python:${PYTHON_VERSION}-slim
# Configure environment variables
ARG PYTHON_VERSION
ENV DEBIAN_FRONTEND=noninteractive
ENV MUJOCO_GL="egl"
ENV PATH="/opt/venv/bin:$PATH"
ARG DEBIAN_FRONTEND=noninteractive
# Install dependencies and set up Python in a single layer
# Install apt dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git \
build-essential cmake git git-lfs \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher libgeos-dev \
&& ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
&& python -m venv /opt/venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Clone repository and install LeRobot in a single layer
COPY . /lerobot
# Create virtual environment
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Install LeRobot
RUN git lfs install
RUN git clone https://github.com/huggingface/lerobot.git /lerobot
WORKDIR /lerobot
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
--extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \
--extra-index-url https://download.pytorch.org/whl/cpu
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"
# Execute in bash shell rather than python
CMD ["/bin/bash"]


@@ -8,7 +8,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential cmake git \
build-essential cmake git git-lfs \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
speech-dispatcher libgeos-dev \
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
@@ -18,7 +18,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& echo "source /opt/venv/bin/activate" >> /root/.bashrc
# Clone repository and install LeRobot in a single layer
COPY . /lerobot
WORKDIR /lerobot
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
RUN git lfs install \
&& git clone https://github.com/huggingface/lerobot.git . \
&& /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
&& /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]"


@@ -185,7 +185,7 @@ sudo chmod 666 /dev/ttyACM1
#### d. Update config file
IMPORTANTLY: Now that you have your ports of leader and follower arm and ip address of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
IMPORTANTLY: Now that you have your ports of leader and follower arm and ip adress of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like:
```python
@RobotConfig.register_subclass("lekiwi")
@dataclass
@@ -324,7 +324,7 @@ You should see on your laptop something like this: ```[INFO] Connected to remote
| F | Decrease speed |
> [!TIP]
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
> If you use a different keyboard you can change the keys for each commmand in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
## Troubleshoot communication


@@ -2,7 +2,7 @@ This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-ro
## Source the parts
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials with link to source the parts, as well as the instructions to 3D print the parts and advice if it's your first time printing or if you don't own a 3D printer already.
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.


@@ -398,7 +398,7 @@ And here are the corresponding positions for the leader arm:
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to mesure if the values changed negatively or positively.
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
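A schematic sketch of that procedure (names and the 4096-step resolution are illustrative, not the calibration code itself):

```python
import numpy as np

RESOLUTION = 4096  # steps per full turn (illustrative)

def compute_homing_offset(zero_reading: np.ndarray) -> np.ndarray:
    # Shift raw motor values so the rough "zero" pose reads about 0.
    return -zero_reading

def compute_drive_mode(rotated_reading: np.ndarray, homing_offset: np.ndarray) -> np.ndarray:
    # After rotating every joint to roughly +90 degrees, a motor whose
    # corrected value moved negatively is mounted in the inverted direction.
    corrected = rotated_reading + homing_offset
    return (corrected < 0).astype(int)  # 1 = inverted, 0 = normal
```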
@@ -663,7 +663,7 @@ camera.disconnect()
**Instantiate your robot with cameras**
Additionally, you can set up your robot to work with your cameras.
Additionaly, you can set up your robot to work with your cameras.
Modify the following Python code with the appropriate camera names and configurations:
```python
@@ -825,8 +825,8 @@ It contains:
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchrously.
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchrously.
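These delta times boil down to timing each read/write with a high-resolution clock. A minimal sketch of the measurement (the sleep stands in for an actual read):

```python
import time

def timed(fn):
    # Report one call's delta time in ms plus the equivalent frequency,
    # matching the `dtRlead: 5.06 (197.5hz)` style above.
    start = time.perf_counter()
    result = fn()
    dt = time.perf_counter() - start
    print(f"dt: {dt * 1e3:5.2f} ({1 / dt:.1f}hz)")
    return result

timed(lambda: time.sleep(0.005))  # stand-in for e.g. reading the leader arm
```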
Troubleshooting:
- On Linux, if you encounter a hanging issue when using cameras, uninstall opencv and re-install it with conda:
@@ -846,7 +846,7 @@ At the end of data recording, your dataset will be uploaded on your Hugging Face
echo https://huggingface.co/datasets/${HF_USER}/koch_test
```
### b. Advice for recording dataset
### b. Advices for recording dataset
Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings.


@@ -98,7 +98,7 @@ python lerobot/scripts/control_robot.py \
```
This is equivalent to running `stretch_robot_home.py`
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
> **Note:** If you run any of the LeRobot scripts below and Stretch is not poperly homed, it will automatically home/calibrate first.
**Teleoperate**
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).


@@ -172,10 +172,10 @@ python lerobot/scripts/control_robot.py \
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constent 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.


@@ -92,7 +92,7 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else:
ep_ft_array = data # data is already a np.ndarray
ep_ft_array = data # data is alreay a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array


@@ -226,7 +226,7 @@ class LeRobotDatasetMetadata:
def add_task(self, task: str):
"""
Given a task in natural language, add it to the dictionary of tasks.
Given a task in natural language, add it to the dictionnary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
@@ -389,7 +389,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- info contains various information about the dataset like shapes, keys, fps etc.
- stats stores the dataset statistics of the different modalities for normalization
- tasks contains the prompts for each task of the dataset, which can be used for
task-conditioned training.
task-conditionned training.
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
@@ -848,7 +848,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionary
# Add new tasks to the tasks dictionnary
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:


@@ -152,7 +152,7 @@ def download_raw(raw_dir: Path, repo_id: str):
stacklevel=1,
)
# Send warning if raw_dir isn't well formatted
# Send warning if raw_dir isn't well formated
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn(
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that


@@ -68,9 +68,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
modality_df,
on="timestamp_utc",
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
# matching timestamps that are too far apart, in order to fit the backward constraints. It's not the case for "nearest".
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
# are too far apart.
# are too far appart.
direction="nearest",
tolerance=pd.Timedelta(f"{1 / fps} seconds"),
)
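A self-contained illustration of that choice (toy timestamps, not the dataset's):

```python
import pandas as pd

ref = pd.DataFrame({"timestamp_utc": pd.to_datetime([0.000, 0.033, 0.066], unit="s")})
cam = pd.DataFrame(
    {
        "timestamp_utc": pd.to_datetime([0.001, 0.036, 0.064], unit="s"),
        "frame_idx": [0, 1, 2],
    }
)

# "nearest" may match a slightly future camera frame, but never one far in the
# past just to satisfy a backward constraint.
merged = pd.merge_asof(
    ref,
    cam,
    on="timestamp_utc",
    direction="nearest",
    tolerance=pd.Timedelta(f"{1 / 30} seconds"),
)
print(merged)
```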
@@ -126,7 +126,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
videos_dir.parent.mkdir(parents=True, exist_ok=True)
videos_dir.symlink_to((raw_dir / "videos").absolute())
# sanity check the video paths are well formatted
# sanity check the video paths are well formated
for key in df:
if "observation.images." not in key:
continue
@@ -143,7 +143,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
# sanity check the video path is well formatted
# sanity check the video path is well formated
video_path = videos_dir.parent / data_dict[key][0]["path"]
if not video_path.exists():
raise ValueError(f"Video file not found in {video_path}")


@@ -17,7 +17,7 @@
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
Example:
python lerobot/scripts/push_dataset_to_hub.py \


@@ -222,7 +222,7 @@ def load_episodes(local_dir: Path) -> dict:
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# We wrap episode_stats in a dictionnary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
@@ -445,10 +445,10 @@ def get_episode_data_index(
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
cumulative_lengths = list(accumulate(episode_lengths.values()))
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
"to": torch.LongTensor(cumulative_lengths),
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
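The from/to construction is easiest to see on a toy example (episode lengths are made up):

```python
from itertools import accumulate

import torch

episode_lengths = {0: 100, 1: 120, 2: 90}  # hypothetical frames per episode
cumulative_lengths = list(accumulate(episode_lengths.values()))  # [100, 220, 310]
episode_data_index = {
    "from": torch.LongTensor([0] + cumulative_lengths[:-1]),  # first frame index
    "to": torch.LongTensor(cumulative_lengths),               # one past the last
}
```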


@@ -31,7 +31,6 @@ from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
LOCAL_DIR = Path("data/")
# spellchecker:off
ALOHA_MOBILE_INFO = {
"robot_config": AlohaRobotConfig(),
"license": "mit",
@@ -857,7 +856,6 @@ DATASETS = {
}""").lstrip(),
},
}
# spellchecker:on
def batch_convert():


@@ -17,7 +17,7 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the tasks performed in the dataset. This will allow to easily train models with task-conditioning.
for each of the tasks performed in the dataset. This will allow to easily train models with task-conditionning.
We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.


@@ -73,7 +73,7 @@ def decode_video_frames_torchvision(
last_ts = max(timestamps)
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only)
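A stripped-down sketch of that decode pattern (illustrative file name; assumes a torchvision build whose `VideoReader.seek` supports `keyframes_only`):

```python
from torchvision.io import VideoReader

reader = VideoReader("episode_000000.mp4", "video")  # hypothetical file
first_ts, last_ts = 1.0, 1.5

# Seek lands on the closest key frame at or before first_ts, then decoding
# advances frame by frame until the requested span is covered.
reader.seek(first_ts, keyframes_only=True)
frames = []
for frame in reader:
    if frame["pts"] > last_ts:
        break
    frames.append(frame["data"])  # (C, H, W) uint8 tensor
```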


@@ -37,12 +37,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
Args:
cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
Raises:
ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not installed
ModuleNotFoundError: If the requested env package is not intalled
Returns:
gym.vector.VectorEnv: The parallelized gym.env instance.


@@ -64,7 +64,7 @@ class ACTConfig(PreTrainedConfig):
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
`None` means no pretrained weights.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution.


@@ -68,7 +68,7 @@ class DiffusionConfig(PreTrainedConfig):
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
@@ -99,7 +99,7 @@ class DiffusionConfig(PreTrainedConfig):
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See
`LeRobotDataset` and `load_previous_and_future_frames` for more information. Note, this defaults
`LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults
to False as the original Diffusion Policy implementation does the same.
"""


@@ -2,7 +2,7 @@
Convert pi0 parameters from Jax to Pytorch
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
and install the required libraries.
and install the required librairies.
```bash
cd ~/code/openpi


@@ -76,7 +76,7 @@ class TDMPCConfig(PreTrainedConfig):
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
be zero.
uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
trajectory values (this is the λ coefficient in eqn 4 of FOWM).
trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
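A minimal sketch of that elite-weighted update (shapes and names are illustrative, not the TDMPC implementation):

```python
import torch

def refit_gaussian(values: torch.Tensor, samples: torch.Tensor, temperature: float, n_elites: int):
    # values: (n_samples,) trajectory values; samples: (n_samples, horizon, action_dim).
    elite_values, idx = values.topk(n_elites)
    elites = samples[idx]
    weights = torch.softmax(elite_values / temperature, dim=0)  # sharper for low temperature
    mean = (weights[:, None, None] * elites).sum(dim=0)
    std = ((weights[:, None, None] * (elites - mean) ** 2).sum(dim=0)).sqrt()
    return mean, std  # parameters for the next CEM iteration
```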
@@ -165,7 +165,7 @@ class TDMPCConfig(PreTrainedConfig):
"""Input validation (not exhaustive)."""
if self.n_gaussian_samples <= 0:
raise ValueError(
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
raise ValueError(


@@ -66,7 +66,7 @@ class VQBeTConfig(PreTrainedConfig):
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).


@@ -485,7 +485,7 @@ class VQBeTHead(nn.Module):
def forward(self, x, **kwargs) -> dict:
# N is the batch size, and T is the number of action query tokens, which are processed through the same GPT
N, T, _ = x.shape
# we calculate N and T side parallelly. Thus, the dimensions would be
# we calculate N and T side parallely. Thus, the dimensions would be
# (batch size * number of action query tokens, action chunk size, action dimension)
x = einops.rearrange(x, "N T WA -> (N T) WA")
@@ -772,7 +772,7 @@ class VqVae(nn.Module):
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
The vq_layer uses residual VQs.
This class contains functions for training the encoder and decoder along with the residual VQ layer (for training phase 1),
This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1),
as well as functions to help BeT training part in training phase 2.
"""


@@ -38,7 +38,7 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
This file is part of a VQ-BeT that utilizes code from the following repositories:
- Vector Quantize PyTorch code is licensed under the MIT License:
Original source: https://github.com/lucidrains/vector-quantize-pytorch
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
Original source: https://github.com/karpathy/nanoGPT
@@ -289,7 +289,7 @@ class GPT(nn.Module):
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
Original source: https://github.com/lucidrains/vector-quantize-pytorch
Origianl source: https://github.com/lucidrains/vector-quantize-pytorch
- The vector-quantize-pytorch code is licensed under the MIT License:
@@ -1349,9 +1349,9 @@ class EuclideanCodebook(nn.Module):
# calculate distributed variance
variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
distributed.all_reduce(variance_number)
batch_variance = variance_number / num_vectors
variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
distributed.all_reduce(variance_numer)
batch_variance = variance_numer / num_vectors
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
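The idea, extracted as a sketch (assumes an initialized process group and that `num_vectors` was already summed across ranks the same way):

```python
import torch.distributed as distributed
from einops import reduce

def distributed_batch_variance(data, batch_mean, num_vectors):
    # Each rank sums squared deviations locally; all_reduce adds them across
    # ranks; dividing by the global vector count yields the batch variance.
    variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
    distributed.all_reduce(variance_number)
    return variance_number / num_vectors
```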


@@ -66,7 +66,7 @@ class RecordControlConfig(ControlConfig):
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.


@@ -242,7 +242,7 @@ class DriveMode(enum.Enum):
class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
LINEAR = 1
@@ -610,7 +610,7 @@ class DynamixelMotorsBus:
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
# Subtract the homing offsets to come back to actual motor range of values
# Substract the homing offsets to come back to actual motor range of values
# which can be arbitrary.
values[i] -= homing_offset
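A worked instance of that conversion (the offset value is made up):

```python
HALF_TURN_DEGREE = 180
resolution = 4096  # 0-centered range becomes [-2048, 2048]

degrees = 90
steps = degrees / HALF_TURN_DEGREE * (resolution // 2)  # -> 1024.0
homing_offset = -300                                     # arbitrary per-motor offset
raw_value = steps - homing_offset                        # -> 1324.0, back in motor range
```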


@@ -221,7 +221,7 @@ class DriveMode(enum.Enum):
class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
LINEAR = 1
@@ -591,7 +591,7 @@ class FeetechMotorsBus:
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
# Subtract the homing offsets to come back to actual motor range of values
# Substract the homing offsets to come back to actual motor range of values
# which can be arbitrary.
values[i] -= homing_offset
@@ -632,7 +632,7 @@ class FeetechMotorsBus:
track["prev"][idx] = values[i]
continue
# Detect a full rotation occurred
# Detect a full rotation occured
if abs(track["prev"][idx] - values[i]) > 2048:
# Position went below 0 and got reset to 4095
if track["prev"][idx] < values[i]:
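The wrap handling sketched on its own (the 2048 threshold is half of the 4096-step resolution):

```python
def update_turns(prev: int, current: int, turns: int) -> int:
    # A jump larger than half a turn between consecutive reads means the raw
    # position wrapped around the [0, 4095] range.
    if abs(prev - current) > 2048:
        if prev < current:   # went below 0 and got reset to 4095
            turns -= 1
        else:                # went above 4095 and got reset to 0
            turns += 1
    return turns
```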


@@ -87,7 +87,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
@@ -115,7 +115,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
# Joints with linear motions (like gripper of Aloha) are experessed in nominal range of [0, 100]
calib_idx = arm.motor_names.index("gripper")
calib_mode[calib_idx] = CalibrationMode.LINEAR.name


@@ -443,7 +443,7 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))


@@ -44,7 +44,7 @@ class ManipulatorRobot:
# TODO(rcadene): Implement force feedback
"""This class allows to control any manipulator robot of various number of motors.
Non exhaustive list of robots:
Non exaustive list of robots:
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed
by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
@@ -55,7 +55,7 @@ class ManipulatorRobot:
robot = ManipulatorRobot(KochRobotConfig())
```
Example of overwriting motors during instantiation:
Example of overwritting motors during instantiation:
```python
# Defines how to communicate with the motors of the leader and follower arms
leader_arms = {
@@ -90,7 +90,7 @@ class ManipulatorRobot:
robot = ManipulatorRobot(robot_config)
```
Example of overwriting cameras during instantiation:
Example of overwritting cameras during instantiation:
```python
# Defines how to communicate with 2 cameras connected to the computer.
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
@@ -348,7 +348,7 @@ class ManipulatorRobot:
set_operating_mode_(self.follower_arms[name])
# Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
# TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor
self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")
@@ -500,7 +500,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
# Populate output dictionnaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
@@ -540,7 +540,7 @@ class ManipulatorRobot:
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries and format to pytorch
# Populate output dictionnaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:


@@ -108,7 +108,7 @@ class StretchRobot(StretchAPI):
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
# Populate output dictionnaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
@@ -153,7 +153,7 @@ class StretchRobot(StretchAPI):
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
# Populate output dictionnaries
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:


@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Callable
from lerobot.common.utils.utils import format_big_number
@@ -93,12 +93,14 @@ class MetricsTracker:
num_episodes: int,
metrics: dict[str, AverageMeter],
initial_step: int = 0,
accelerator: Callable = None,
):
self.__dict__.update({k: None for k in self.__keys__})
self._batch_size = batch_size
self._num_frames = num_frames
self._avg_samples_per_ep = num_frames / num_episodes
self.metrics = metrics
self.accelerator = accelerator
self.steps = initial_step
# A sample is an (observation,action) pair, where observation and action
@@ -128,7 +130,7 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
self.samples += self._batch_size
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
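The scaling is simple arithmetic: one step consumes `batch_size` frames on each process, so the global counter advances by `batch_size * num_processes`. For instance:

```python
batch_size = 8
num_processes = 4  # accelerator.num_processes on a hypothetical 4-GPU run

# Without an accelerator the multiplier falls back to 1, as in the code above.
samples_per_step = batch_size * num_processes
print(samples_per_step)  # 32 (observation, action) pairs per optimization step
```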


@@ -16,7 +16,7 @@
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator
from typing import Any, Generator, Callable
import numpy as np
import torch
@@ -163,14 +163,16 @@ def set_rng_state(random_state_dict: dict[str, Any]):
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_seed(seed) -> None:
def set_seed(seed, accelerator: Callable = None) -> None:
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
if accelerator:
from accelerate.utils import set_seed
set_seed(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:


@@ -20,10 +20,10 @@ import platform
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable
import numpy as np
import torch
from typing import Any
def none_or_int(value):
if value == "None":
@@ -50,12 +50,12 @@ def auto_select_torch_device() -> torch.device:
return torch.device("cpu")
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
def get_safe_torch_device(try_device: str, log: bool = False, accelerator: Callable = None) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match try_device:
case "cuda":
assert torch.cuda.is_available()
device = torch.device("cuda")
device = accelerator.device if accelerator else torch.device("cuda")
case "mps":
assert torch.backends.mps.is_available()
device = torch.device("mps")
@@ -103,7 +103,7 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def init_logging():
def init_logging(accelerator: Callable = None):
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
@@ -120,7 +120,10 @@ def init_logging():
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
if accelerator is not None and not accelerator.is_main_process:
# Disable duplicate logging on non-main processes
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
logging.getLogger().setLevel(logging.WARNING)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -216,3 +219,18 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
except TypeError:
# If a TypeError is raised, the string is not a valid dtype
return False
def is_launched_with_accelerate() -> bool:
return "ACCELERATE_MIXED_PRECISION" in os.environ
def get_accelerate_config(accelerator: Callable = None) -> dict[str, Any]:
config = {}
if not accelerator:
return config
config["num_processes"] = accelerator.num_processes
config["device"] = str(accelerator.device)
config["distributed_type"] = str(accelerator.distributed_type)
config["mixed_precision"] = accelerator.mixed_precision
config["gradient_accumulation_steps"] = accelerator.gradient_accumulation_steps
return config
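Pieced together, these helpers might be used as follows (a sketch assuming `accelerate` is installed; `accelerate launch` is what exports ACCELERATE_MIXED_PRECISION):

```python
from lerobot.common.utils.utils import get_accelerate_config, is_launched_with_accelerate

if is_launched_with_accelerate():
    from accelerate import Accelerator

    accelerator = Accelerator()
    # e.g. {"num_processes": 2, "device": "cuda:0", "distributed_type": "MULTI_GPU", ...}
    print(get_accelerate_config(accelerator))
```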


@@ -27,7 +27,7 @@ class DatasetConfig:
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
# datsets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None


@@ -102,7 +102,7 @@ class TrainPipelineConfig(HubMixin):
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
raise FileExistsError(
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Output directory {self.output_dir} alreay exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten."
)
elif not self.output_dir:


@@ -59,8 +59,8 @@ python lerobot/scripts/control_sim_robot.py record \
```
**NOTE**: You can use your keyboard to control data recording flow.
- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode.
- Tap right arrow key '->' to early exit while recording an episode and go to reseting the environment.
- Tap right arrow key '->' to early exit while reseting the environment and got to recording the next episode.
- Tap left arrow key '<-' to early exit and re-record the current episode.
- Tap escape key 'esc' to stop the data recording.
This might require a sudo permission to allow your terminal to monitor keyboard events.
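A hypothetical sketch of such key handling with pynput (the actual script's wiring may differ):

```python
from pynput import keyboard

events = {"exit_early": False, "rerecord": False, "stop": False}

def on_press(key):
    if key == keyboard.Key.right:
        events["exit_early"] = True  # skip to resetting / the next episode
    elif key == keyboard.Key.left:
        events["rerecord"] = True    # re-record the current episode
    elif key == keyboard.Key.esc:
        events["stop"] = True        # stop data recording
        return False                 # also stops the listener

listener = keyboard.Listener(on_press=on_press)
listener.start()
```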
@@ -131,7 +131,7 @@ def none_or_int(value):
def init_sim_calibration(robot, cfg):
# Constants necessary for transforming the joint pos of the real robot to the sim
# depending on the robot description used in that sim.
# depending on the robot discription used in that sim.
start_pos = np.array(robot.leader_arms.main.calibration["start_pos"])
axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi
@@ -445,7 +445,7 @@ if __name__ == "__main__":
type=int,
default=0,
help=(
"Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; "
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."


@@ -175,7 +175,7 @@ def push_dataset_to_hub(
# Robustify when `local_dir` is str instead of Path
local_dir = Path(local_dir)
# Send warning if local_dir isn't well formatted
# Send warning if local_dir isn't well formated
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
warnings.warn(
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",


@@ -17,7 +17,7 @@ import logging
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any
from typing import Any, Callable
import torch
from termcolor import colored
@@ -46,6 +46,8 @@ from lerobot.common.utils.utils import (
get_safe_torch_device,
has_method,
init_logging,
get_accelerate_config,
is_launched_with_accelerate
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
@@ -63,30 +65,41 @@ def update_policy(
lr_scheduler=None,
use_amp: bool = False,
lock=None,
accelerator: Callable = None,
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
if accelerator:
accelerator.backward(loss)
accelerator.unscale_gradients(optimizer=optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
optimizer.step()
else:
grad_scaler.scale(loss).backward()
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad()
@@ -94,9 +107,13 @@ def update_policy(
if lr_scheduler is not None:
lr_scheduler.step()
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
if accelerator:
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): # FIXME(mshukor): avoid accelerator.unwrap_model ?
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
else:
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
@@ -106,10 +123,14 @@ def update_policy(
@parser.wrap()
def train(cfg: TrainPipelineConfig):
def train(cfg: TrainPipelineConfig, accelerator: Callable = None):
cfg.validate()
logging.info(pformat(cfg.to_dict()))
if accelerator and not accelerator.is_main_process:
# Disable logging on non-main processes.
cfg.wandb.enable = False
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
@@ -117,10 +138,10 @@ def train(cfg: TrainPipelineConfig):
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
set_seed(cfg.seed)
set_seed(cfg.seed, accelerator=accelerator)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -141,7 +162,7 @@ def train(cfg: TrainPipelineConfig):
device=device,
ds_meta=dataset.meta,
)
policy.to(device)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
@@ -184,6 +205,10 @@ def train(cfg: TrainPipelineConfig):
pin_memory=device.type != "cpu",
drop_last=False,
)
if accelerator:
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)
policy.train()
@@ -197,7 +222,7 @@ def train(cfg: TrainPipelineConfig):
}
train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step, accelerator=accelerator
)
logging.info("Start offline training on a fixed dataset")
@@ -219,6 +244,7 @@ def train(cfg: TrainPipelineConfig):
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
accelerator=accelerator,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -238,21 +264,26 @@ def train(cfg: TrainPipelineConfig):
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
if cfg.save_checkpoint and is_saving_step and (not accelerator or accelerator.is_main_process):
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
save_checkpoint(checkpoint_dir, step, cfg, policy if not accelerator else accelerator.unwrap_model(policy), optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
if accelerator:
accelerator.wait_for_everyone()
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.use_amp and not accelerator else nullcontext(),
):
eval_info = eval_policy(
eval_env,
policy,
policy if not accelerator else accelerator.unwrap_model(policy),
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
@@ -265,7 +296,7 @@ def train(cfg: TrainPipelineConfig):
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step, accelerator=None
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
@@ -283,4 +314,11 @@ def train(cfg: TrainPipelineConfig):
if __name__ == "__main__":
init_logging()
train()
if is_launched_with_accelerate():
import accelerate
# We set step_scheduler_with_optimizer False to prevent accelerate from
# adjusting the lr_scheduler steps based on the num_processes
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
train(accelerator=accelerator)
else:
train()
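Condensed onto a toy model, the accelerate path above looks roughly like this (illustrative, not the actual training script; run it with `accelerate launch`):

```python
import accelerate
import torch
from torch.utils.data import DataLoader, TensorDataset

# step_scheduler_with_optimizer=False keeps accelerate from rescaling scheduler
# steps by num_processes, as noted above.
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 2)), batch_size=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for x, y in dataloader:
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)                        # replaces grad_scaler.scale(loss).backward()
    accelerator.unscale_gradients(optimizer=optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
    optimizer.step()
    optimizer.zero_grad()

if accelerator.is_main_process:                       # checkpoint once, with the bare model
    torch.save(accelerator.unwrap_model(model).state_dict(), "toy.pt")
```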


@@ -245,17 +245,16 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
header += [f"{column_name}_{i}" for i in range(dim_state)]
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
column_names = [f"{column_name}_{i}" for i in range(dim_state)]
column_names = [f"motor_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
header += column_names
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
@@ -365,7 +364,7 @@ def visualize_dataset_html(
template_folder=template_dir,
)
else:
# Create a simlink from the dataset video folder containing mp4 files to the output directory
# Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"


@@ -14,7 +14,21 @@
<!-- Use [Alpin.js](https://alpinejs.dev), a lightweight and easy to learn JS framework -->
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
<body class="flex flex-col md:flex-row h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()">
<body class="flex flex-col md:flex-row h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
const { keyCode, key } = e;
if (keyCode === 32 || key === ' ') {
e.preventDefault();
$refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
}else if (key === 'ArrowDown' || key === 'ArrowUp'){
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
const lowestEpisodeId = {{ episodes }}.at(0);
const highestEpisodeId = {{ episodes }}.at(-1);
if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
window.location.href = `./episode_${nextEpisodeId}`;
}
}
}">
<!-- Sidebar -->
<div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
<a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
@@ -38,55 +52,25 @@
<p>Episodes:</p>
<!-- episodes menu for medium & large screens -->
<div class="ml-2 hidden md:block" x-data="episodePagination">
<ul>
<template x-for="episode in paginatedEpisodes" :key="episode">
<li class="font-mono text-sm mt-0.5">
<a :href="'episode_' + episode"
:class="{'underline': true, 'font-bold -ml-1': episode == {{ episode_id }}}"
x-text="'Episode ' + episode"></a>
</li>
</template>
</ul>
<div class="flex items-center mt-3 text-xs" x-show="totalPages > 1">
<button @click="prevPage()"
class="px-2 py-1 bg-slate-800 rounded mr-2"
:class="{'opacity-50 cursor-not-allowed': page === 1}"
:disabled="page === 1">
&laquo; Prev
</button>
<span class="font-mono mr-2" x-text="` ${page} / ${totalPages}`"></span>
<button @click="nextPage()"
class="px-2 py-1 bg-slate-800 rounded"
:class="{'opacity-50 cursor-not-allowed': page === totalPages}"
:disabled="page === totalPages">
Next &raquo;
</button>
</div>
</div>
<ul class="ml-2 hidden md:block">
{% for episode in episodes %}
<li class="font-mono text-sm mt-0.5">
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
Episode {{ episode }}
</a>
</li>
{% endfor %}
</ul>
<!-- episodes menu for small screens -->
<div class="flex overflow-x-auto md:hidden" x-data="episodePagination">
<button @click="prevPage()"
class="px-2 bg-slate-800 rounded mr-2"
:class="{'opacity-50 cursor-not-allowed': page === 1}"
:disabled="page === 1">&laquo;</button>
<div class="flex">
<template x-for="(episode, index) in paginatedEpisodes" :key="episode">
<p class="font-mono text-sm mt-0.5 px-2"
:class="{
'font-bold': episode == {{ episode_id }},
'border-r': index !== paginatedEpisodes.length - 1
}">
<a :href="'episode_' + episode" x-text="episode"></a>
</p>
</template>
</div>
<button @click="nextPage()"
class="px-2 bg-slate-800 rounded ml-2"
:class="{'opacity-50 cursor-not-allowed': page === totalPages}"
:disabled="page === totalPages">&raquo; </button>
<div class="flex overflow-x-auto md:hidden">
{% for episode in episodes %}
<p class="font-mono text-sm mt-0.5 border-r last:border-r-0 px-2 {% if episode_id == episode %}font-bold{% endif %}">
<a href="episode_{{ episode }}" class="">
{{ episode }}
</a>
</p>
{% endfor %}
</div>
</div>
@@ -246,16 +230,14 @@
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
<p x-text="`${rowLabels[rowIndex]}`"></p>
</div>
</td>
<template x-for="(cell, colIndex) in row">
<td x-show="cell" class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2" :class="{ 'hidden': cell.isNull }">
<div class="flex gap-x-2">
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
<span x-text="`${!cell.isNull ? cell.label : null}`"></span>
</div>
<span class="w-14 text-right" x-text="`${!cell.isNull ? (typeof cell.value === 'number' ? cell.value.toFixed(2) : cell.value) : null}`"
<div class="flex gap-x-2 w-24 justify-between px-2" :class="{ 'hidden': cell.isNull }">
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
<span x-text="`${!cell.isNull ? cell.value.toFixed(2) : null}`"
:style="`color: ${cell.color}`"></span>
</div>
</td>
@@ -297,6 +279,7 @@
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
videosKeysSelected: [],
columns: {{ columns | tojson }},
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
// alpine initialization
init() {
@@ -469,68 +452,6 @@
}
};
}
document.addEventListener('alpine:init', () => {
// Episode pagination component
Alpine.data('episodePagination', () => ({
episodes: {{ episodes }},
pageSize: 100,
page: 1,
init() {
// Find which page contains the current episode_id
const currentEpisodeId = {{ episode_id }};
const episodeIndex = this.episodes.indexOf(currentEpisodeId);
if (episodeIndex !== -1) {
this.page = Math.floor(episodeIndex / this.pageSize) + 1;
}
},
get totalPages() {
return Math.ceil(this.episodes.length / this.pageSize);
},
get paginatedEpisodes() {
const start = (this.page - 1) * this.pageSize;
const end = start + this.pageSize;
return this.episodes.slice(start, end);
},
nextPage() {
if (this.page < this.totalPages) {
this.page++;
}
},
prevPage() {
if (this.page > 1) {
this.page--;
}
}
}));
});
</script>
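The page lookup in the pagination component's init() maps the current episode to its page with integer division. The same arithmetic in Python, with invented numbers:

page_size = 100
episodes = list(range(500))
current_episode = 250

episode_index = episodes.index(current_episode)
page = episode_index // page_size + 1          # 3, since pages are 1-based
total_pages = -(-len(episodes) // page_size)   # ceil(500 / 100) = 5
start, end = (page - 1) * page_size, page * page_size
assert current_episode in episodes[start:end]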
<script>
window.addEventListener('keydown', (e) => {
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
const { keyCode, key } = e;
if (keyCode === 32 || key === ' ') {
e.preventDefault();
const btnPause = document.querySelector('[x-ref="btnPause"]');
const btnPlay = document.querySelector('[x-ref="btnPlay"]');
btnPause.classList.contains('hidden') ? btnPlay.click() : btnPause.click();
} else if (key === 'ArrowDown' || key === 'ArrowUp') {
const episodes = {{ episodes }}; // Access episodes directly from the Jinja template
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
const lowestEpisodeId = episodes.at(0);
const highestEpisodeId = episodes.at(-1);
if (nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId) {
window.location.href = `./episode_${nextEpisodeId}`;
}
}
});
</script>
</body>

@@ -116,19 +116,6 @@ exclude = [
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
[tool.typos]
default.extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on" # spellchecker:<on|off>
]
default.extend-ignore-identifiers-re = [
# Add individual words here to ignore them
"2nd",
"pn",
"ser",
"ein",
]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
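The two regexes removed from [tool.typos] above matched inline spellchecker directives. A hypothetical Python snippet showing the comment forms they recognized:

teh_value = 1  # spellchecker:disable-line

# spellchecker:off
soem_deliberate_typos = "ignored in this region"
# spellchecker:on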

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb7b74f919adf8d4478585f65c54997e6f3bccab67eadb4048300108586a4163
size 5104

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfbc3b1ad5e3b94311edda0f04db002b26117b0719b73dfdb56dd483dc9c409d
size 31672

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e39afdf1f3db8a72a1095a5a0ffdb7e67f478a28bd73e59cda197687da8d236c
size 68

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5dd39a554c9c3db537e98c9ceade024d172c46c4fa7ce9e27601b94116445417
size 33400

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5ec46abc5a3c85675a5ee4a1bb362eecb3ff4c546082ff309c89fc7821f38bd
size 515400

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50303d05caea725c4a240f1389424d6c2361961f2cee729a0010e909ebffed81
size 31672

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9bb9b195d32e05550af0edd5df88fcc761c829ab8c4b129ba970a723f39b46ee
size 68

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:683a2038185f3d070e7d7c0c31e4aa75067c11bf798daa41c9fab336f4183fda
size 33400

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
size 5104

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
size 31672

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
size 68

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
size 33400

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
size 515400

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
size 31672

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
size 68

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
size 33400

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e56a5d30778395534a06ad1742843700424614168fc26d1098558012a5df90c6
size 5104

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9007dd51c748db4ecd6d75e70bdcabf8c312454ac97bf6710895a12e7288557
size 31672

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:170bd8365dfd1e36e8f56814bf8bc2057aa0d035c41212b7ddd7e4b9feee1633
size 68

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11884346b41ca102c672bb0f361ea9699d2f8b33bb503038b53cc7e7fafd281b
size 34920

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c259ea9c40aab3841ca35b2a2e708d8829b0a9163b2f9e5efd28f1c65848293
size 4600

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77cd4127a45ded2f75d85ca9c17537808517614ef16fb3035cebb1b45547acbf
size 47424

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcff4b736e95d685d56830b501f4542b081f4334f72d28a7415809f4d9d15d0f
size 68

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:60775e91ed550aae66cb0547ee4b0e38917f29172e942671e9361b3812364df6
size 49120

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
size 992

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
size 47424

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
size 68

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
size 49120

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
size 200

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
size 200

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a
size 472

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
size 16904

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
size 240

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
size 36312

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1
size 472

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
size 16904

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
size 240

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
size 36312

@@ -18,11 +18,11 @@ def _generate_image(width: int, height: int):
    return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)

def cvtColor(color_image, color_conversion):  # noqa: N802
    if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
        return color_image[:, :, [2, 1, 0]]
    else:
        raise NotImplementedError(color_conversion)

def rotate(color_image, rotation):
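The [2, 1, 0] fancy index used by cvtColor reverses the channel axis, which is all an RGB/BGR conversion needs. A quick check with a single invented pixel:

import numpy as np

rgb = np.zeros((1, 1, 3), dtype=np.uint8)
rgb[0, 0] = [255, 128, 0]   # R, G, B

bgr = rgb[:, :, [2, 1, 0]]  # fancy indexing reverses the channel order
print(bgr[0, 0])            # [  0 128 255], the same pixel in B, G, R order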

@@ -27,13 +27,16 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig

def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
    # TODO(rcadene, aliberts): env_name?
    set_seed(1337)
    train_cfg = TrainPipelineConfig(
        # TODO(rcadene, aliberts): remove dataset download
        dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
        policy=make_policy_config(policy_name, **policy_kwargs),
        device="cpu",
        **train_kwargs,
    )
    train_cfg.validate()  # Needed for auto-setting some parameters
@@ -51,11 +54,8 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
    batch = next(iter(dataloader))
    loss, output_dict = policy.forward(batch)

    if output_dict is not None:
        output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
        output_dict["loss"] = loss
    else:
        output_dict = {"loss": loss}
    output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
    output_dict["loss"] = loss

    loss.backward()

    grad_stats = {}
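Both versions of the diff funnel the policy output into a tensors-only dict before serialization, since safetensors can only store tensors. A toy illustration with invented dict contents:

import torch

loss = torch.tensor(0.25)
output_dict = {"l1_loss": torch.tensor(0.1), "debug_note": "not a tensor"}

output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
print(sorted(output_dict))  # ['l1_loss', 'loss']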
@@ -101,27 +101,30 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
    return output_dict, grad_stats, param_stats, actions

def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
    if output_dir.exists():
        print(f"Overwrite existing safetensors in '{output_dir}':")
        print(f" - Validate with: `git add {output_dir}`")
        print(f" - Revert with: `git checkout -- {output_dir}`")
        shutil.rmtree(output_dir)
def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
    env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"

    output_dir.mkdir(parents=True, exist_ok=True)
    output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
    save_file(output_dict, output_dir / "output_dict.safetensors")
    save_file(grad_stats, output_dir / "grad_stats.safetensors")
    save_file(param_stats, output_dir / "param_stats.safetensors")
    save_file(actions, output_dir / "actions.safetensors")
    if env_policy_dir.exists():
        print(f"Overwrite existing safetensors in '{env_policy_dir}':")
        print(f" - Validate with: `git add {env_policy_dir}`")
        print(f" - Revert with: `git checkout -- {env_policy_dir}`")
        shutil.rmtree(env_policy_dir)
    env_policy_dir.mkdir(parents=True, exist_ok=True)

    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
    save_file(output_dict, env_policy_dir / "output_dict.safetensors")
    save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
    save_file(param_stats, env_policy_dir / "param_stats.safetensors")
    save_file(actions, env_policy_dir / "actions.safetensors")
if __name__ == "__main__":
    artifacts_cfg = [
        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
    env_policies = [
        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
        ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
        (
            "lerobot/pusht",
            "pusht",
            "diffusion",
            {
                "n_action_steps": 8,
@@ -130,17 +133,18 @@ if __name__ == "__main__":
            },
            "",
        ),
        ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
        ("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
        (
            "lerobot/aloha_sim_insertion_human",
            "aloha",
            "act",
            {"n_action_steps": 1000, "chunk_size": 1000},
            "1000_steps",
            "_1000_steps",
        ),
    ]

    if len(artifacts_cfg) == 0:
    if len(env_policies) == 0:
        raise RuntimeError("No policies were provided!")

    for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
        ds_name = ds_repo_id.split("/")[-1]
        output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
        save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
    for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
        save_policy_to_safetensors(
            "tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
        )

@@ -27,7 +27,7 @@ import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
# Maximum absolute difference between two consecutive images recorded by a camera.
# This value differs with respect to the camera.
MAX_PIXEL_DIFFERENCE = 25
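A sketch of how a threshold like this might be exercised, with synthetic frames standing in for camera captures:

import numpy as np

MAX_PIXEL_DIFFERENCE = 25
frame_a = np.random.randint(0, 256, size=(48, 64, 3), dtype=np.uint8)
frame_b = np.clip(frame_a.astype(int) + np.random.randint(-5, 6, frame_a.shape), 0, 255)

max_abs_diff = np.abs(frame_a.astype(int) - frame_b).max()
assert max_abs_diff <= MAX_PIXEL_DIFFERENCE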

@@ -179,7 +179,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
    policy.save_pretrained(pretrained_policy_path)

    # In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
    # during inference, to reach constant fps, so we test this here.
    if robot_type == "aloha":
        num_image_writer_processes = 1

@@ -486,7 +486,7 @@ def test_backward_compatibility(repo_id):
        old_frame = load_file(test_dir / f"frame_{i}.safetensors")  # noqa: B023

        # ignore language instructions (if they exist) in language conditioned datasets
        # TODO (michel-aractingi): transform language obs to language embeddings via tokenizer
        new_frame.pop("language_instruction", None)
        old_frame.pop("language_instruction", None)
        new_frame.pop("task", None)

@@ -32,10 +32,10 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
        episode_lengths.insert(ep_idx, len(ep_table))

    cumulative_lengths = list(accumulate(episode_lengths))
    return {
        "from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
        "to": np.array(cumulative_lengths, dtype=np.int64),
    }
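A worked example of the from/to arithmetic, with invented episode lengths:

from itertools import accumulate

import numpy as np

episode_lengths = [3, 2, 4]
cumulative_lengths = list(accumulate(episode_lengths))  # [3, 5, 9]

episode_data_index = {
    "from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),  # [0 3 5]
    "to": np.array(cumulative_lengths, dtype=np.int64),               # [3 5 9]
}
# episode i spans rows from[i] (inclusive) through to[i] (exclusive)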

@@ -88,7 +88,7 @@ def test_motors_bus(request, motor_type, mock):
    motors_bus = make_motors_bus(motor_type, mock=mock)

    # Test that reading and writing before connecting raises an error
    with pytest.raises(RobotDeviceNotConnectedError):
        motors_bus.read("Torque_Enable")
    with pytest.raises(RobotDeviceNotConnectedError):

@@ -166,7 +166,7 @@ def test_delta_timestamps_within_tolerance():
    buffer.tolerance_s = 0.04
    item = buffer[2]

    data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
    torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
    assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
    assert not is_pad.any(), "Unexpected padding detected"
@@ -236,7 +236,7 @@ def test_compute_sampler_weights_trivial(
    elif online_sampling_ratio == 1:
        expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
    expected_weights /= expected_weights.sum()
    torch.testing.assert_close(weights, expected_weights)
    assert torch.allclose(weights, expected_weights)

def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
@@ -248,7 +248,7 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
    )
    torch.testing.assert_close(
    assert torch.allclose(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
    )
@@ -261,7 +261,7 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
    )
    torch.testing.assert_close(
    assert torch.allclose(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
    )
@@ -279,4 +279,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
        online_sampling_ratio=0.5,
        online_drop_n_last_frames=1,
    )
    torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
    assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
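For readers comparing the two idioms swapped throughout this file: torch.testing.assert_close raises an AssertionError with a detailed mismatch report, while torch.allclose only returns a bool, so the failure message must be written by hand. The values below are invented:

import torch

a = torch.tensor([0.5, 0.125])
b = torch.tensor([0.5, 0.1251])

assert torch.allclose(a, b, atol=1e-3), "tensors differ"  # plain boolean check
torch.testing.assert_close(a, b, atol=1e-3, rtol=0.0)     # raises with a detailed diff on failure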

Some files were not shown because too many files have changed in this diff.