forked from tangger/lerobot
Compare commits
24 Commits
feat/autop
...
fix/lint_w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e511e7eda5 | ||
|
|
32fffd4bbb | ||
|
|
03c7cf8a63 | ||
|
|
f5ed3723f0 | ||
|
|
b104be0d04 | ||
|
|
f9e4a1f5c4 | ||
|
|
0eb56cec14 | ||
|
|
e59ef036e1 | ||
|
|
9b380eaf67 | ||
|
|
1187604ba0 | ||
|
|
5c6f2d2cd0 | ||
|
|
652fedf69c | ||
|
|
85214ec303 | ||
|
|
dffa5a18db | ||
|
|
301f152a34 | ||
|
|
0ed08c0b1f | ||
|
|
254bc707e7 | ||
|
|
074f0ac8fe | ||
|
|
25c63ccf63 | ||
|
|
5e9473806c | ||
|
|
10706ed753 | ||
|
|
0b8205a8a0 | ||
|
|
57ae509823 | ||
|
|
5d24ce3160 |
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Misc
|
# Misc
|
||||||
.git
|
.git
|
||||||
tmp
|
tmp
|
||||||
|
|||||||
14
.gitattributes
vendored
14
.gitattributes
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
|||||||
14
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
14
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
name: "\U0001F41B Bug Report"
|
name: "\U0001F41B Bug Report"
|
||||||
description: Submit a bug report to help us improve LeRobot
|
description: Submit a bug report to help us improve LeRobot
|
||||||
body:
|
body:
|
||||||
|
|||||||
14
.github/workflows/build-docker-images.yml
vendored
14
.github/workflows/build-docker-images.yml
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
||||||
name: Builds
|
name: Builds
|
||||||
|
|||||||
14
.github/workflows/nightly-tests.yml
vendored
14
.github/workflows/nightly-tests.yml
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
||||||
name: Nightly
|
name: Nightly
|
||||||
|
|||||||
161
.github/workflows/pr_style_bot.yml
vendored
161
.github/workflows/pr_style_bot.yml
vendored
@@ -1,161 +0,0 @@
|
|||||||
# Adapted from https://github.com/huggingface/diffusers/blob/main/.github/workflows/pr_style_bot.yml
|
|
||||||
name: PR Style Bot
|
|
||||||
|
|
||||||
on:
|
|
||||||
issue_comment:
|
|
||||||
types: [created]
|
|
||||||
|
|
||||||
permissions: {}
|
|
||||||
|
|
||||||
env:
|
|
||||||
PYTHON_VERSION: "3.10"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-permissions:
|
|
||||||
if: >
|
|
||||||
contains(github.event.comment.body, '@bot /style') &&
|
|
||||||
github.event.issue.pull_request != null
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
is_authorized: ${{ steps.check_user_permission.outputs.has_permission }}
|
|
||||||
steps:
|
|
||||||
- name: Check user permission
|
|
||||||
id: check_user_permission
|
|
||||||
uses: actions/github-script@v6
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const comment_user = context.payload.comment.user.login;
|
|
||||||
const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
username: comment_user
|
|
||||||
});
|
|
||||||
|
|
||||||
const authorized =
|
|
||||||
permission.permission === 'admin' ||
|
|
||||||
permission.permission === 'write';
|
|
||||||
|
|
||||||
console.log(
|
|
||||||
`User ${comment_user} has permission level: ${permission.permission}, ` +
|
|
||||||
`authorized: ${authorized} (admins & maintainers allowed)`
|
|
||||||
);
|
|
||||||
|
|
||||||
core.setOutput('has_permission', authorized);
|
|
||||||
|
|
||||||
run-style-bot:
|
|
||||||
needs: check-permissions
|
|
||||||
if: needs.check-permissions.outputs.is_authorized == 'true'
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
pull-requests: write
|
|
||||||
steps:
|
|
||||||
- name: Extract PR details
|
|
||||||
id: pr_info
|
|
||||||
uses: actions/github-script@v6
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const prNumber = context.payload.issue.number;
|
|
||||||
const { data: pr } = await github.rest.pulls.get({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
pull_number: prNumber
|
|
||||||
});
|
|
||||||
|
|
||||||
// We capture both the branch ref and the "full_name" of the head repo
|
|
||||||
// so that we can check out the correct repository & branch (including forks).
|
|
||||||
core.setOutput("prNumber", prNumber);
|
|
||||||
core.setOutput("headRef", pr.head.ref);
|
|
||||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
|
||||||
|
|
||||||
- name: Check out PR branch
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
env:
|
|
||||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
|
||||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
|
||||||
with:
|
|
||||||
persist-credentials: true
|
|
||||||
# Instead of checking out the base repo, use the contributor's repo name
|
|
||||||
repository: ${{ env.HEADREPOFULLNAME }}
|
|
||||||
ref: ${{ env.HEADREF }}
|
|
||||||
# You may need fetch-depth: 0 for being able to push
|
|
||||||
fetch-depth: 0
|
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Debug
|
|
||||||
env:
|
|
||||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
|
||||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
|
||||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
|
||||||
run: |
|
|
||||||
echo "PR number: ${PRNUMBER}"
|
|
||||||
echo "Head Ref: ${HEADREF}"
|
|
||||||
echo "Head Repo Full Name: ${HEADREPOFULLNAME}"
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
|
||||||
|
|
||||||
- name: Get Ruff Version from pre-commit-config.yaml
|
|
||||||
id: get-ruff-version
|
|
||||||
run: |
|
|
||||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
|
||||||
echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Install Ruff
|
|
||||||
env:
|
|
||||||
RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
|
|
||||||
run: python -m pip install "ruff==${RUFF_VERSION}"
|
|
||||||
|
|
||||||
- name: Ruff check
|
|
||||||
run: ruff check --fix
|
|
||||||
|
|
||||||
- name: Ruff format
|
|
||||||
run: ruff format
|
|
||||||
|
|
||||||
- name: Commit and push changes
|
|
||||||
id: commit_and_push
|
|
||||||
env:
|
|
||||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
|
||||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
|
||||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
run: |
|
|
||||||
echo "HEADREPOFULLNAME: ${HEADREPOFULLNAME}, HEADREF: ${HEADREF}"
|
|
||||||
# Configure git with the Actions bot user
|
|
||||||
git config user.name "github-actions[bot]"
|
|
||||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
|
||||||
git config --local lfs.https://github.com/.locksverify false
|
|
||||||
|
|
||||||
# Make sure your 'origin' remote is set to the contributor's fork
|
|
||||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"
|
|
||||||
|
|
||||||
# If there are changes after running style/quality, commit them
|
|
||||||
if [ -n "$(git status --porcelain)" ]; then
|
|
||||||
git add .
|
|
||||||
git commit -m "Apply style fixes"
|
|
||||||
# Push to the original contributor's forked branch
|
|
||||||
git push origin HEAD:${HEADREF}
|
|
||||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
|
||||||
else
|
|
||||||
echo "No changes to commit."
|
|
||||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Comment on PR with workflow run link
|
|
||||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
|
||||||
uses: actions/github-script@v6
|
|
||||||
with:
|
|
||||||
script: |
|
|
||||||
const prNumber = parseInt(process.env.prNumber, 10);
|
|
||||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
|
||||||
|
|
||||||
await github.rest.issues.createComment({
|
|
||||||
owner: context.repo.owner,
|
|
||||||
repo: context.repo.repo,
|
|
||||||
issue_number: prNumber,
|
|
||||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
|
||||||
});
|
|
||||||
env:
|
|
||||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
|
||||||
14
.github/workflows/quality.yml
vendored
14
.github/workflows/quality.yml
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
name: Quality
|
name: Quality
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
|||||||
14
.github/workflows/test-docker-build.yml
vendored
14
.github/workflows/test-docker-build.yml
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
||||||
name: Test Dockerfiles
|
name: Test Dockerfiles
|
||||||
|
|||||||
14
.github/workflows/test.yml
vendored
14
.github/workflows/test.yml
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
name: Tests
|
name: Tests
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
|||||||
14
.github/workflows/trufflehog.yml
vendored
14
.github/workflows/trufflehog.yml
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
|
||||||
|
|||||||
14
.gitignore
vendored
14
.gitignore
vendored
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
logs
|
logs
|
||||||
tmp
|
tmp
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
exclude: ^(tests/data)
|
exclude: ^(tests/data)
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.10
|
python: python3.10
|
||||||
|
|||||||
32
Makefile
32
Makefile
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
.PHONY: tests
|
.PHONY: tests
|
||||||
|
|
||||||
PYTHON_PATH := $(shell which python)
|
PYTHON_PATH := $(shell which python)
|
||||||
@@ -33,6 +47,7 @@ test-act-ete-train:
|
|||||||
--policy.dim_model=64 \
|
--policy.dim_model=64 \
|
||||||
--policy.n_action_steps=20 \
|
--policy.n_action_steps=20 \
|
||||||
--policy.chunk_size=20 \
|
--policy.chunk_size=20 \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=aloha \
|
--env.type=aloha \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||||
@@ -47,7 +62,6 @@ test-act-ete-train:
|
|||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
--log_freq=1 \
|
--log_freq=1 \
|
||||||
--wandb.enable=false \
|
--wandb.enable=false \
|
||||||
--device=$(DEVICE) \
|
|
||||||
--output_dir=tests/outputs/act/
|
--output_dir=tests/outputs/act/
|
||||||
|
|
||||||
test-act-ete-train-resume:
|
test-act-ete-train-resume:
|
||||||
@@ -58,11 +72,11 @@ test-act-ete-train-resume:
|
|||||||
test-act-ete-eval:
|
test-act-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=aloha \
|
--env.type=aloha \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1
|
||||||
--device=$(DEVICE)
|
|
||||||
|
|
||||||
test-diffusion-ete-train:
|
test-diffusion-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
@@ -70,6 +84,7 @@ test-diffusion-ete-train:
|
|||||||
--policy.down_dims='[64,128,256]' \
|
--policy.down_dims='[64,128,256]' \
|
||||||
--policy.diffusion_step_embed_dim=32 \
|
--policy.diffusion_step_embed_dim=32 \
|
||||||
--policy.num_inference_steps=10 \
|
--policy.num_inference_steps=10 \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=pusht \
|
--env.type=pusht \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--dataset.repo_id=lerobot/pusht \
|
--dataset.repo_id=lerobot/pusht \
|
||||||
@@ -84,21 +99,21 @@ test-diffusion-ete-train:
|
|||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
--log_freq=1 \
|
--log_freq=1 \
|
||||||
--wandb.enable=false \
|
--wandb.enable=false \
|
||||||
--device=$(DEVICE) \
|
|
||||||
--output_dir=tests/outputs/diffusion/
|
--output_dir=tests/outputs/diffusion/
|
||||||
|
|
||||||
test-diffusion-ete-eval:
|
test-diffusion-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=pusht \
|
--env.type=pusht \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1
|
||||||
--device=$(DEVICE)
|
|
||||||
|
|
||||||
test-tdmpc-ete-train:
|
test-tdmpc-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
--policy.type=tdmpc \
|
--policy.type=tdmpc \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=xarm \
|
--env.type=xarm \
|
||||||
--env.task=XarmLift-v0 \
|
--env.task=XarmLift-v0 \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
@@ -114,15 +129,14 @@ test-tdmpc-ete-train:
|
|||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
--log_freq=1 \
|
--log_freq=1 \
|
||||||
--wandb.enable=false \
|
--wandb.enable=false \
|
||||||
--device=$(DEVICE) \
|
|
||||||
--output_dir=tests/outputs/tdmpc/
|
--output_dir=tests/outputs/tdmpc/
|
||||||
|
|
||||||
test-tdmpc-ete-eval:
|
test-tdmpc-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||||
|
--policy.device=$(DEVICE) \
|
||||||
--env.type=xarm \
|
--env.type=xarm \
|
||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--env.task=XarmLift-v0 \
|
--env.task=XarmLift-v0 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1
|
||||||
--device=$(DEVICE)
|
|
||||||
|
|||||||
16
README.md
16
README.md
@@ -92,20 +92,15 @@ git clone https://github.com/huggingface/lerobot.git
|
|||||||
cd lerobot
|
cd lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
Create a virtual environment with Python 3.10 and activate it using [`uv`](https://github.com/astral-sh/uv):
|
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||||
```bash
|
```bash
|
||||||
# Install uv if you haven't already
|
conda create -y -n lerobot python=3.10
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
conda activate lerobot
|
||||||
|
|
||||||
# Create and activate virtual environment with Python 3.10
|
|
||||||
uv venv .venv --python=3.10
|
|
||||||
source .venv/bin/activate # On Unix/macOS
|
|
||||||
# .venv\Scripts\activate # On Windows
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Install 🤗 LeRobot:
|
Install 🤗 LeRobot:
|
||||||
```bash
|
```bash
|
||||||
uv pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
> **NOTE:** Depending on your platform, If you encounter any build errors during this step
|
> **NOTE:** Depending on your platform, If you encounter any build errors during this step
|
||||||
@@ -389,3 +384,6 @@ Additionally, if you are using any of the particular policy architecture, pretra
|
|||||||
year={2024}
|
year={2024}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ def parse_int_or_none(value) -> int | None:
|
|||||||
def check_datasets_formats(repo_ids: list) -> None:
|
def check_datasets_formats(repo_ids: list) -> None:
|
||||||
for repo_id in repo_ids:
|
for repo_id in repo_ids:
|
||||||
dataset = LeRobotDataset(repo_id)
|
dataset = LeRobotDataset(repo_id)
|
||||||
|
# TODO(Steven): Seems this API has changed
|
||||||
if dataset.video:
|
if dataset.video:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||||
@@ -30,7 +44,7 @@ pretrained_policy_path = "lerobot/diffusion_pusht"
|
|||||||
# OR a path to a local outputs/train folder.
|
# OR a path to a local outputs/train folder.
|
||||||
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
|
||||||
|
|
||||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
|
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||||
|
|
||||||
# Initialize evaluation environment to render two observation types:
|
# Initialize evaluation environment to render two observation types:
|
||||||
# an image of the scene and state/position of the agent. The environment
|
# an image of the scene and state/position of the agent. The environment
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||||
|
|
||||||
Once you have trained a model with this script, you can try to evaluate it on
|
Once you have trained a model with this script, you can try to evaluate it on
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
|
||||||
|
|
||||||
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -208,7 +222,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
||||||
repo_id = "lerobot/pusht"
|
repository_id = "lerobot/pusht"
|
||||||
|
|
||||||
modes = ["video", "image", "keypoints"]
|
modes = ["video", "image", "keypoints"]
|
||||||
# Uncomment if you want to try with a specific mode
|
# Uncomment if you want to try with a specific mode
|
||||||
@@ -216,13 +230,13 @@ if __name__ == "__main__":
|
|||||||
# modes = ["image"]
|
# modes = ["image"]
|
||||||
# modes = ["keypoints"]
|
# modes = ["keypoints"]
|
||||||
|
|
||||||
raw_dir = Path("data/lerobot-raw/pusht_raw")
|
data_dir = Path("data/lerobot-raw/pusht_raw")
|
||||||
for mode in modes:
|
for available_mode in modes:
|
||||||
if mode in ["image", "keypoints"]:
|
if available_mode in ["image", "keypoints"]:
|
||||||
repo_id += f"_{mode}"
|
repository_id += f"_{available_mode}"
|
||||||
|
|
||||||
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
||||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
main(data_dir, repo_id=repository_id, mode=available_mode)
|
||||||
|
|
||||||
# Uncomment if you want to load the local dataset and explore it
|
# Uncomment if you want to load the local dataset and explore it
|
||||||
# dataset = LeRobotDataset(repo_id=repo_id)
|
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
# keys
|
# keys
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import packaging.version
|
import packaging.version
|
||||||
|
|
||||||
V2_MESSAGE = """
|
V2_MESSAGE = """
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
dataset = MultiLeRobotDataset(
|
# dataset = MultiLeRobotDataset(
|
||||||
cfg.dataset.repo_id,
|
# cfg.dataset.repo_id,
|
||||||
# TODO(aliberts): add proper support for multi dataset
|
# # TODO(aliberts): add proper support for multi dataset
|
||||||
# delta_timestamps=delta_timestamps,
|
# # delta_timestamps=delta_timestamps,
|
||||||
image_transforms=image_transforms,
|
# image_transforms=image_transforms,
|
||||||
video_backend=cfg.dataset.video_backend,
|
# video_backend=cfg.dataset.video_backend,
|
||||||
)
|
# )
|
||||||
logging.info(
|
# logging.info(
|
||||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
# "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
# f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||||
)
|
# )
|
||||||
|
|
||||||
if cfg.dataset.use_imagenet_stats:
|
if cfg.dataset.use_imagenet_stats:
|
||||||
for key in dataset.meta.camera_keys:
|
for key in dataset.meta.camera_keys:
|
||||||
|
|||||||
@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
|||||||
print(f"Error writing image {fpath}: {e}")
|
print(f"Error writing image {fpath}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def worker_thread_loop(queue: queue.Queue):
|
def worker_thread_loop(task_queue: queue.Queue):
|
||||||
while True:
|
while True:
|
||||||
item = queue.get()
|
item = task_queue.get()
|
||||||
if item is None:
|
if item is None:
|
||||||
queue.task_done()
|
task_queue.task_done()
|
||||||
break
|
break
|
||||||
image_array, fpath = item
|
image_array, fpath = item
|
||||||
write_image(image_array, fpath)
|
write_image(image_array, fpath)
|
||||||
queue.task_done()
|
task_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
def worker_process(queue: queue.Queue, num_threads: int):
|
def worker_process(task_queue: queue.Queue, num_threads: int):
|
||||||
threads = []
|
threads = []
|
||||||
for _ in range(num_threads):
|
for _ in range(num_threads):
|
||||||
t = threading.Thread(target=worker_thread_loop, args=(queue,))
|
t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
t.start()
|
t.start()
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ class LeRobotDatasetMetadata:
|
|||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||||
|
self.stats = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if force_cache_sync:
|
if force_cache_sync:
|
||||||
@@ -102,10 +103,10 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
def load_metadata(self):
|
def load_metadata(self):
|
||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
|
||||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||||
self.episodes = load_episodes(self.root)
|
self.episodes = load_episodes(self.root)
|
||||||
if self._version < packaging.version.parse("v2.1"):
|
if self.version < packaging.version.parse("v2.1"):
|
||||||
self.stats = load_stats(self.root)
|
self.stats = load_stats(self.root)
|
||||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||||
else:
|
else:
|
||||||
@@ -127,7 +128,7 @@ class LeRobotDatasetMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _version(self) -> packaging.version.Version:
|
def version(self) -> packaging.version.Version:
|
||||||
"""Codebase version used to create this dataset."""
|
"""Codebase version used to create this dataset."""
|
||||||
return packaging.version.parse(self.info["codebase_version"])
|
return packaging.version.parse(self.info["codebase_version"])
|
||||||
|
|
||||||
@@ -321,8 +322,9 @@ class LeRobotDatasetMetadata:
|
|||||||
robot_type = robot.robot_type
|
robot_type = robot.robot_type
|
||||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
"Some cameras in your %s robot don't have an fps matching the fps of your dataset."
|
||||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
|
||||||
|
robot.robot_type,
|
||||||
)
|
)
|
||||||
elif features is None:
|
elif features is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -486,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.meta = LeRobotDatasetMetadata(
|
self.meta = LeRobotDatasetMetadata(
|
||||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||||
)
|
)
|
||||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
|
||||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||||
self.stats = aggregate_stats(episodes_stats)
|
self.stats = aggregate_stats(episodes_stats)
|
||||||
|
|
||||||
@@ -518,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self,
|
self,
|
||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
tags: list | None = None,
|
tags: list | None = None,
|
||||||
license: str | None = "apache-2.0",
|
dataset_license: str | None = "apache-2.0",
|
||||||
tag_version: bool = True,
|
tag_version: bool = True,
|
||||||
push_videos: bool = True,
|
push_videos: bool = True,
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
@@ -561,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||||
card = create_lerobot_dataset_card(
|
card = create_lerobot_dataset_card(
|
||||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
|
||||||
)
|
)
|
||||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
@@ -842,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
|
episode_buffer = None
|
||||||
if not episode_data:
|
if not episode_data:
|
||||||
episode_buffer = self.episode_buffer
|
episode_buffer = self.episode_buffer
|
||||||
|
|
||||||
@@ -1086,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||||
extra_keys = set(ds.features).difference(intersection_features)
|
extra_keys = set(ds.features).difference(intersection_features)
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
"keys %s of %s were disabled as they are not contained in all the other datasets.",
|
||||||
"other datasets."
|
extra_keys,
|
||||||
|
repo_id,
|
||||||
)
|
)
|
||||||
self.disabled_features.update(extra_keys)
|
self.disabled_features.update(extra_keys)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
|
|||||||
# rechunk recompress
|
# rechunk recompress
|
||||||
group.move(name, tmp_key)
|
group.move(name, tmp_key)
|
||||||
old_arr = group[tmp_key]
|
old_arr = group[tmp_key]
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||||
source=old_arr,
|
source=old_arr,
|
||||||
dest=group,
|
dest=group,
|
||||||
name=name,
|
name=name,
|
||||||
@@ -192,7 +192,7 @@ class ReplayBuffer:
|
|||||||
else:
|
else:
|
||||||
root = zarr.group(store=store)
|
root = zarr.group(store=store)
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||||
)
|
)
|
||||||
data_group = root.create_group("data", overwrite=True)
|
data_group = root.create_group("data", overwrite=True)
|
||||||
@@ -205,7 +205,7 @@ class ReplayBuffer:
|
|||||||
if cks == value.chunks and cpr == value.compressor:
|
if cks == value.chunks and cpr == value.compressor:
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
this_path = "/data/" + key
|
this_path = "/data/" + key
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=src_store,
|
source=src_store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path=this_path,
|
source_path=this_path,
|
||||||
@@ -214,7 +214,7 @@ class ReplayBuffer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# copy with recompression
|
# copy with recompression
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||||
source=value,
|
source=value,
|
||||||
dest=data_group,
|
dest=data_group,
|
||||||
name=key,
|
name=key,
|
||||||
@@ -275,7 +275,7 @@ class ReplayBuffer:
|
|||||||
compressors = {}
|
compressors = {}
|
||||||
if self.backend == "zarr":
|
if self.backend == "zarr":
|
||||||
# recompression free copy
|
# recompression free copy
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=self.root.store,
|
source=self.root.store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path="/meta",
|
source_path="/meta",
|
||||||
@@ -297,7 +297,7 @@ class ReplayBuffer:
|
|||||||
if cks == value.chunks and cpr == value.compressor:
|
if cks == value.chunks and cpr == value.compressor:
|
||||||
# copy without recompression
|
# copy without recompression
|
||||||
this_path = "/data/" + key
|
this_path = "/data/" + key
|
||||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||||
source=self.root.store,
|
source=self.root.store,
|
||||||
dest=store,
|
dest=store,
|
||||||
source_path=this_path,
|
source_path=this_path,
|
||||||
|
|||||||
@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
|
|||||||
)
|
)
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||||
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
def download_all_raw_datasets(data_dir: Path | None = None):
|
def download_all_raw_datasets(data_dir: Path | None = None):
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
|
|||||||
assert data[f"/observations/images/{camera}"].ndim == 2
|
assert data[f"/observations/images/{camera}"].ndim == 2
|
||||||
else:
|
else:
|
||||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
_, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||||
|
|
||||||
|
|
||||||
@@ -103,6 +103,7 @@ def load_from_raw(
|
|||||||
|
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
velocity = None
|
||||||
if "/observations/qvel" in ep:
|
if "/observations/qvel" in ep:
|
||||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||||
if "/observations/effort" in ep:
|
if "/observations/effort" in ep:
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
|
|||||||
if fps is None:
|
if fps is None:
|
||||||
fps = 30
|
fps = 30
|
||||||
|
|
||||||
|
# TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
|
||||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||||
hf_dataset = to_hf_dataset(data_dict, video)
|
hf_dataset = to_hf_dataset(data_dict, video)
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
def load_from_raw(
|
||||||
|
raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
|
||||||
|
):
|
||||||
# Load data stream that will be used as reference for the timestamps synchronization
|
# Load data stream that will be used as reference for the timestamps synchronization
|
||||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||||
if len(reference_files) == 0:
|
if len(reference_files) == 0:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
|||||||
|
|
||||||
num_images = len(imgs_array)
|
num_images = len(imgs_array)
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
_ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||||
|
|
||||||
|
|
||||||
def get_default_encoding() -> dict:
|
def get_default_encoding() -> dict:
|
||||||
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
|||||||
episode_data_index = {"from": [], "to": []}
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
current_episode = None
|
current_episode = None
|
||||||
"""
|
# The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
# For instance, the following is a valid episode_index:
|
||||||
For instance, the following is a valid episode_index:
|
# [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
#
|
||||||
|
# Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
# ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
# {
|
||||||
{
|
# "from": [0, 3, 7],
|
||||||
"from": [0, 3, 7],
|
# "to": [3, 7, 12]
|
||||||
"to": [3, 7, 12]
|
# }
|
||||||
}
|
|
||||||
"""
|
|
||||||
if len(hf_dataset) == 0:
|
if len(hf_dataset) == 0:
|
||||||
episode_data_index = {
|
episode_data_index = {
|
||||||
"from": torch.tensor([]),
|
"from": torch.tensor([]),
|
||||||
"to": torch.tensor([]),
|
"to": torch.tensor([]),
|
||||||
}
|
}
|
||||||
return episode_data_index
|
return episode_data_index
|
||||||
|
idx = None
|
||||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||||
if episode_idx != current_episode:
|
if episode_idx != current_episode:
|
||||||
# We encountered a new episode, so we append its starting location to the "from" list
|
# We encountered a new episode, so we append its starting location to the "from" list
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
|
|||||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing transform() implementation
|
||||||
class RandomSubsetApply(Transform):
|
class RandomSubsetApply(Transform):
|
||||||
"""Apply a random subset of N transformations from a list of transformations.
|
"""Apply a random subset of N transformations from a list of transformations.
|
||||||
|
|
||||||
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
|||||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing transform() implementation
|
||||||
class ImageTransforms(Transform):
|
class ImageTransforms(Transform):
|
||||||
"""A class to compose image transforms based on configuration."""
|
"""A class to compose image transforms based on configuration."""
|
||||||
|
|
||||||
|
|||||||
@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
|||||||
|
|
||||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
# Embed image bytes into the table before saving to parquet
|
# Embed image bytes into the table before saving to parquet
|
||||||
format = dataset.format
|
ds_format = dataset.format
|
||||||
dataset = dataset.with_format("arrow")
|
dataset = dataset.with_format("arrow")
|
||||||
dataset = dataset.map(embed_table_storage, batched=False)
|
dataset = dataset.map(embed_table_storage, batched=False)
|
||||||
dataset = dataset.with_format(**format)
|
dataset = dataset.with_format(**ds_format)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def load_json(fpath: Path) -> Any:
|
def load_json(fpath: Path) -> Any:
|
||||||
with open(fpath) as f:
|
with open(fpath, encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def write_json(data: dict, fpath: Path) -> None:
|
def write_json(data: dict, fpath: Path) -> None:
|
||||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -300,7 +300,7 @@ def check_version_compatibility(
|
|||||||
if v_check.major < v_current.major and enforce_breaking_major:
|
if v_check.major < v_current.major and enforce_breaking_major:
|
||||||
raise BackwardCompatibilityError(repo_id, v_check)
|
raise BackwardCompatibilityError(repo_id, v_check)
|
||||||
elif v_check.minor < v_current.minor:
|
elif v_check.minor < v_current.minor:
|
||||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||||
|
|
||||||
|
|
||||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||||
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
|||||||
if compatibles:
|
if compatibles:
|
||||||
return_version = max(compatibles)
|
return_version = max(compatibles)
|
||||||
if return_version < target_version:
|
if return_version < target_version:
|
||||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
logging.warning(
|
||||||
|
"Revision %s for %s not found, using version v%s", version, repo_id, return_version
|
||||||
|
)
|
||||||
return f"v{return_version}"
|
return f"v{return_version}"
|
||||||
|
|
||||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||||
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
for key, ft in features.items():
|
for key, ft in features.items():
|
||||||
shape = ft["shape"]
|
shape = ft["shape"]
|
||||||
if ft["dtype"] in ["image", "video"]:
|
if ft["dtype"] in ["image", "video"]:
|
||||||
type = FeatureType.VISUAL
|
feature_type = FeatureType.VISUAL
|
||||||
if len(shape) != 3:
|
if len(shape) != 3:
|
||||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||||
|
|
||||||
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
elif key == "observation.environment_state":
|
elif key == "observation.environment_state":
|
||||||
type = FeatureType.ENV
|
feature_type = FeatureType.ENV
|
||||||
elif key.startswith("observation"):
|
elif key.startswith("observation"):
|
||||||
type = FeatureType.STATE
|
feature_type = FeatureType.STATE
|
||||||
elif key == "action":
|
elif key == "action":
|
||||||
type = FeatureType.ACTION
|
feature_type = FeatureType.ACTION
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
policy_features[key] = PolicyFeature(
|
policy_features[key] = PolicyFeature(
|
||||||
type=type,
|
type=feature_type,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -871,11 +871,11 @@ def batch_convert():
|
|||||||
try:
|
try:
|
||||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||||
status = f"{repo_id}: success."
|
status = f"{repo_id}: success."
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
except Exception:
|
except Exception:
|
||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
|||||||
|
|
||||||
json_path = v2_dir / STATS_PATH
|
json_path = v2_dir / STATS_PATH
|
||||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with open(json_path, "w") as f:
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(serialized_stats, f, indent=4)
|
json.dump(serialized_stats, f, indent=4)
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
with open(json_path) as f:
|
with open(json_path, encoding="utf-8") as f:
|
||||||
stats_json = json.load(f)
|
stats_json = json.load(f)
|
||||||
|
|
||||||
stats_json = flatten_dict(stats_json)
|
stats_json = flatten_dict(stats_json)
|
||||||
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
|
|||||||
dtype = ft.dtype
|
dtype = ft.dtype
|
||||||
shape = (1,)
|
shape = (1,)
|
||||||
names = None
|
names = None
|
||||||
if isinstance(ft, datasets.Sequence):
|
elif isinstance(ft, datasets.Sequence):
|
||||||
assert isinstance(ft.feature, datasets.Value)
|
assert isinstance(ft.feature, datasets.Value)
|
||||||
dtype = ft.feature.dtype
|
dtype = ft.feature.dtype
|
||||||
shape = (ft.length,)
|
shape = (ft.length,)
|
||||||
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
|
|||||||
dtype = "video"
|
dtype = "video"
|
||||||
shape = None # Add shape later
|
shape = None # Add shape later
|
||||||
names = ["height", "width", "channels"]
|
names = ["height", "width", "channels"]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Feature type {ft._type} not supported.")
|
||||||
|
|
||||||
features[key] = {
|
features[key] = {
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
@@ -358,9 +360,9 @@ def move_videos(
|
|||||||
if len(video_dirs) == 1:
|
if len(video_dirs) == 1:
|
||||||
video_path = video_dirs[0] / video_file
|
video_path = video_dirs[0] / video_file
|
||||||
else:
|
else:
|
||||||
for dir in video_dirs:
|
for v_dir in video_dirs:
|
||||||
if (dir / video_file).is_file():
|
if (v_dir / video_file).is_file():
|
||||||
video_path = dir / video_file
|
video_path = v_dir / video_file
|
||||||
break
|
break
|
||||||
|
|
||||||
video_path.rename(work_dir / target_path)
|
video_path.rename(work_dir / target_path)
|
||||||
@@ -652,6 +654,7 @@ def main():
|
|||||||
if not args.local_dir:
|
if not args.local_dir:
|
||||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||||
|
|
||||||
|
robot_config = None
|
||||||
if args.robot is not None:
|
if args.robot is not None:
|
||||||
robot_config = make_robot_config(args.robot)
|
robot_config = make_robot_config(args.robot)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -36,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
|
|||||||
return f"{repo_id}: skipped (no diff)"
|
return f"{repo_id}: skipped (no diff)"
|
||||||
|
|
||||||
if diff_meta_parquet:
|
if diff_meta_parquet:
|
||||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
|
||||||
assert diff_meta_parquet == {"language_instruction"}
|
assert diff_meta_parquet == {"language_instruction"}
|
||||||
lerobot_metadata.features.pop("language_instruction")
|
lerobot_metadata.features.pop("language_instruction")
|
||||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||||
@@ -65,7 +79,7 @@ def batch_fix():
|
|||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
|
|
||||||
logging.info(status)
|
logging.info(status)
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def batch_convert():
|
|||||||
except Exception:
|
except Exception:
|
||||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||||
|
|
||||||
with open(logfile, "a") as file:
|
with open(logfile, "a", encoding="utf-8") as file:
|
||||||
file.write(status + "\n")
|
file.write(status + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||||
2.1. It will:
|
2.1. It will:
|
||||||
@@ -31,6 +45,9 @@ V21 = "v2.1"
|
|||||||
|
|
||||||
|
|
||||||
class SuppressWarnings:
|
class SuppressWarnings:
|
||||||
|
def __init__(self):
|
||||||
|
self.previous_level = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||||
logging.getLogger().setLevel(logging.ERROR)
|
logging.getLogger().setLevel(logging.ERROR)
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ def decode_video_frames_torchvision(
|
|||||||
for frame in reader:
|
for frame in reader:
|
||||||
current_ts = frame["pts"]
|
current_ts = frame["pts"]
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
logging.info("frame loaded at timestamp=%.4f", current_ts)
|
||||||
loaded_frames.append(frame["data"])
|
loaded_frames.append(frame["data"])
|
||||||
loaded_ts.append(current_ts)
|
loaded_ts.append(current_ts)
|
||||||
if current_ts >= last_ts:
|
if current_ts >= last_ts:
|
||||||
@@ -118,7 +118,7 @@ def decode_video_frames_torchvision(
|
|||||||
closest_ts = loaded_ts[argmin_]
|
closest_ts = loaded_ts[argmin_]
|
||||||
|
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"{closest_ts=}")
|
logging.info("closest_ts=%s", closest_ts)
|
||||||
|
|
||||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||||
closest_frames = closest_frames.type(torch.float32) / 255
|
closest_frames = closest_frames.type(torch.float32) / 255
|
||||||
@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
|||||||
"json",
|
"json",
|
||||||
str(video_path),
|
str(video_path),
|
||||||
]
|
]
|
||||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(
|
||||||
|
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
@@ -263,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
|
|||||||
"json",
|
"json",
|
||||||
str(video_path),
|
str(video_path),
|
||||||
]
|
]
|
||||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
result = subprocess.run(
|
||||||
|
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||||
|
)
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1,15 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
@@ -18,7 +32,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1,15 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .optimizers import OptimizerConfig as OptimizerConfig
|
from .optimizers import OptimizerConfig as OptimizerConfig
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
return "adam"
|
return "adam"
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def build(self) -> torch.optim.Optimizer:
|
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .act.configuration_act import ACTConfig as ACTConfig
|
from .act.configuration_act import ACTConfig as ACTConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class ACTConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -222,6 +222,8 @@ class ACTTemporalEnsembler:
|
|||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||||
|
self.ensembled_actions = None
|
||||||
|
self.ensembled_actions_count = None
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
|
|||||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing forward() implementation
|
||||||
class DiffusionModel(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, config: DiffusionConfig):
|
def __init__(self, config: DiffusionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -203,6 +204,7 @@ class DiffusionModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.num_inference_steps is None:
|
if config.num_inference_steps is None:
|
||||||
|
# TODO(Steven): Consider type check?
|
||||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||||
else:
|
else:
|
||||||
self.num_inference_steps = config.num_inference_steps
|
self.num_inference_steps = config.num_inference_steps
|
||||||
@@ -333,7 +335,7 @@ class DiffusionModel(nn.Module):
|
|||||||
# Sample a random noising timestep for each item in the batch.
|
# Sample a random noising timestep for each item in the batch.
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
high=self.noise_scheduler.config.num_train_timesteps,
|
high=self.noise_scheduler.config.num_train_timesteps, # TODO(Steven): Consider type check?
|
||||||
size=(trajectory.shape[0],),
|
size=(trajectory.shape[0],),
|
||||||
device=trajectory.device,
|
device=trajectory.device,
|
||||||
).long()
|
).long()
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
@@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
|
|
||||||
def make_policy(
|
def make_policy(
|
||||||
cfg: PreTrainedConfig,
|
cfg: PreTrainedConfig,
|
||||||
device: str | torch.device,
|
|
||||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||||
env_cfg: EnvConfig | None = None,
|
env_cfg: EnvConfig | None = None,
|
||||||
) -> PreTrainedPolicy:
|
) -> PreTrainedPolicy:
|
||||||
@@ -88,7 +86,6 @@ def make_policy(
|
|||||||
Args:
|
Args:
|
||||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||||
be loaded with the weights from that path.
|
be loaded with the weights from that path.
|
||||||
device (str): the device to load the policy onto.
|
|
||||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||||
@@ -96,7 +93,7 @@ def make_policy(
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||||
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
|
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PreTrainedPolicy: _description_
|
PreTrainedPolicy: _description_
|
||||||
@@ -111,7 +108,7 @@ def make_policy(
|
|||||||
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
|
||||||
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
|
||||||
# slower than running natively on MPS.
|
# slower than running natively on MPS.
|
||||||
if cfg.type == "vqbet" and str(device) == "mps":
|
if cfg.type == "vqbet" and cfg.device == "mps":
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Current implementation of VQBeT does not support `mps` backend. "
|
"Current implementation of VQBeT does not support `mps` backend. "
|
||||||
"Please use `cpu` or `cuda` backend."
|
"Please use `cpu` or `cuda` backend."
|
||||||
@@ -145,7 +142,7 @@ def make_policy(
|
|||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
policy = policy_cls(**kwargs)
|
policy = policy_cls(**kwargs)
|
||||||
|
|
||||||
policy.to(device)
|
policy.to(cfg.device)
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|||||||
@@ -69,12 +69,12 @@ def create_stats_buffers(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||||
buffer = nn.ParameterDict(
|
buffer = nn.ParameterDict(
|
||||||
{
|
{
|
||||||
"min": nn.Parameter(min, requires_grad=False),
|
"min": nn.Parameter(min_norm, requires_grad=False),
|
||||||
"max": nn.Parameter(max, requires_grad=False),
|
"max": nn.Parameter(max_norm, requires_grad=False),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -170,12 +170,12 @@ class Normalize(nn.Module):
|
|||||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = buffer["min"]
|
min_norm = buffer["min"]
|
||||||
max = buffer["max"]
|
max_norm = buffer["max"]
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
batch[key] = batch[key] * 2 - 1
|
batch[key] = batch[key] * 2 - 1
|
||||||
else:
|
else:
|
||||||
@@ -243,12 +243,12 @@ class Unnormalize(nn.Module):
|
|||||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
batch[key] = batch[key] * std + mean
|
batch[key] = batch[key] * std + mean
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = buffer["min"]
|
min_norm = buffer["min"]
|
||||||
max = buffer["max"]
|
max_norm = buffer["max"]
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||||
batch[key] = (batch[key] + 1) / 2
|
batch[key] = (batch[key] + 1) / 2
|
||||||
batch[key] = batch[key] * (max - min) + min
|
batch[key] = batch[key] * (max_norm - min_norm) + min_norm
|
||||||
else:
|
else:
|
||||||
raise ValueError(norm_mode)
|
raise ValueError(norm_mode)
|
||||||
return batch
|
return batch
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from lerobot.common.optim.optimizers import AdamWConfig
|
from lerobot.common.optim.optimizers import AdamWConfig
|
||||||
@@ -76,7 +90,8 @@ class PI0Config(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||||
|
# Input validation (not exhaustive).
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
@@ -31,7 +45,7 @@ def main():
|
|||||||
|
|
||||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||||
cfg.pretrained_path = ckpt_torch_dir
|
cfg.pretrained_path = ckpt_torch_dir
|
||||||
policy = make_policy(cfg, device, ds_meta=dataset.meta)
|
policy = make_policy(cfg, ds_meta=dataset.meta)
|
||||||
|
|
||||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -41,7 +55,7 @@ def main():
|
|||||||
with open(save_dir / "noise.pkl", "rb") as f:
|
with open(save_dir / "noise.pkl", "rb") as f:
|
||||||
noise = pickle.load(f)
|
noise = pickle.load(f)
|
||||||
|
|
||||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
|
||||||
norm_stats = json.load(f)
|
norm_stats = json.load(f)
|
||||||
|
|
||||||
# Override stats
|
# Override stats
|
||||||
@@ -87,7 +101,7 @@ def main():
|
|||||||
|
|
||||||
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
|
||||||
cfg.pretrained_path = ckpt_torch_dir
|
cfg.pretrained_path = ckpt_torch_dir
|
||||||
policy = make_policy(cfg, device, dataset_meta)
|
policy = make_policy(cfg, dataset_meta)
|
||||||
|
|
||||||
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
|
||||||
# loss_dict["loss"].backward()
|
# loss_dict["loss"].backward()
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from transformers import GemmaConfig, PaliGemmaConfig
|
from transformers import GemmaConfig, PaliGemmaConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Convert pi0 parameters from Jax to Pytorch
|
Convert pi0 parameters from Jax to Pytorch
|
||||||
|
|
||||||
@@ -304,7 +318,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
|||||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||||
|
|
||||||
|
|
||||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
|
||||||
# Break down orbax ckpts - they are in OCDBT
|
# Break down orbax ckpts - they are in OCDBT
|
||||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||||
# process projection params
|
# process projection params
|
||||||
@@ -418,6 +432,6 @@ if __name__ == "__main__":
|
|||||||
convert_pi0_checkpoint(
|
convert_pi0_checkpoint(
|
||||||
checkpoint_dir=args.checkpoint_dir,
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
precision=args.precision,
|
precision=args.precision,
|
||||||
tokenizer_id=args.tokenizer_hub_id,
|
_tokenizer_id=args.tokenizer_hub_id,
|
||||||
output_path=args.output_path,
|
output_path=args.output_path,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,22 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
|
# TODO(Steven): Consider settings this a dependency constraint
|
||||||
if Version(torch.__version__) > Version("2.5.0"):
|
if Version(torch.__version__) > Version("2.5.0"):
|
||||||
# Ffex attention is only available from torch 2.5 onwards
|
# Ffex attention is only available from torch 2.5 onwards
|
||||||
from torch.nn.attention.flex_attention import (
|
from torch.nn.attention.flex_attention import (
|
||||||
@@ -107,7 +122,7 @@ def flex_attention_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||||
attn_output, attention_weights = flex_attention(
|
attn_output, _attention_weights = flex_attention(
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -73,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
cache_dir: str | Path | None = None,
|
cache_dir: str | Path | None = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
map_location: str = "cpu",
|
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T:
|
) -> T:
|
||||||
@@ -98,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
if os.path.isdir(model_id):
|
if os.path.isdir(model_id):
|
||||||
print("Loading weights from local directory")
|
print("Loading weights from local directory")
|
||||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
model_file = hf_hub_download(
|
model_file = hf_hub_download(
|
||||||
@@ -112,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
token=token,
|
token=token,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
|
policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
|
||||||
except HfHubHTTPError as e:
|
except HfHubHTTPError as e:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
policy.to(map_location)
|
policy.to(config.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class TDMPCConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if self.n_gaussian_samples <= 0:
|
if self.n_gaussian_samples <= 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||||
|
|||||||
@@ -88,6 +88,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
for param in self.model_target.parameters():
|
for param in self.model_target.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
self._queues = None
|
||||||
|
self._prev_mean: torch.Tensor | None = None
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -108,7 +111,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||||
# CEM for the next step.
|
# CEM for the next step.
|
||||||
self._prev_mean: torch.Tensor | None = None
|
self._prev_mean = None
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
@@ -514,6 +517,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
|||||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): forward implementation missing
|
||||||
class TDMPCTOLD(nn.Module):
|
class TDMPCTOLD(nn.Module):
|
||||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||||
|
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class VQBeTConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
"""Input validation (not exhaustive)."""
|
# Input validation (not exhaustive).
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
if not self.vision_backbone.startswith("resnet"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ class VQBeTPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
self.vqbet = VQBeTModel(config)
|
self.vqbet = VQBeTModel(config)
|
||||||
|
|
||||||
|
self._queues = None
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -535,7 +537,7 @@ class VQBeTHead(nn.Module):
|
|||||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||||
)
|
)
|
||||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||||
NT, G, choices = cbet_probs.shape
|
NT, _G, choices = cbet_probs.shape
|
||||||
sampled_centers = einops.rearrange(
|
sampled_centers = einops.rearrange(
|
||||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||||
"(NT G) 1 -> NT G",
|
"(NT G) 1 -> NT G",
|
||||||
@@ -578,7 +580,7 @@ class VQBeTHead(nn.Module):
|
|||||||
"decoded_action": decoded_action,
|
"decoded_action": decoded_action,
|
||||||
}
|
}
|
||||||
|
|
||||||
def loss_fn(self, pred, target, **kwargs):
|
def loss_fn(self, pred, target, **_kwargs):
|
||||||
"""
|
"""
|
||||||
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
||||||
|
|
||||||
@@ -605,7 +607,7 @@ class VQBeTHead(nn.Module):
|
|||||||
# Figure out the loss for the actions.
|
# Figure out the loss for the actions.
|
||||||
# First, we need to find the closest cluster center for each ground truth action.
|
# First, we need to find the closest cluster center for each ground truth action.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
_state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||||
|
|
||||||
# Now we can compute the loss.
|
# Now we can compute the loss.
|
||||||
|
|
||||||
@@ -762,6 +764,7 @@ def _replace_submodules(
|
|||||||
return root_module
|
return root_module
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
|
||||||
class VqVae(nn.Module):
|
class VqVae(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -876,13 +879,13 @@ class FocalLoss(nn.Module):
|
|||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.size_average = size_average
|
self.size_average = size_average
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, forward_input, target):
|
||||||
if len(input.shape) == 3:
|
if len(forward_input.shape) == 3:
|
||||||
N, T, _ = input.shape
|
N, T, _ = forward_input.shape
|
||||||
logpt = F.log_softmax(input, dim=-1)
|
logpt = F.log_softmax(forward_input, dim=-1)
|
||||||
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
||||||
elif len(input.shape) == 2:
|
elif len(forward_input.shape) == 2:
|
||||||
logpt = F.log_softmax(input, dim=-1)
|
logpt = F.log_softmax(forward_input, dim=-1)
|
||||||
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
||||||
pt = logpt.exp()
|
pt = logpt.exp()
|
||||||
|
|
||||||
|
|||||||
@@ -34,63 +34,58 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
|||||||
|
|
||||||
# ruff: noqa: N806
|
# ruff: noqa: N806
|
||||||
|
|
||||||
"""
|
# This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||||
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
#
|
||||||
|
# - Vector Quantize PyTorch code is licensed under the MIT License:
|
||||||
|
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
|
#
|
||||||
|
# - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||||
|
# Original source: https://github.com/karpathy/nanoGPT
|
||||||
|
#
|
||||||
|
# We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
||||||
|
|
||||||
- Vector Quantize PyTorch code is licensed under the MIT License:
|
# This is a part for nanoGPT that utilizes code from the following repository:
|
||||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
#
|
||||||
|
# - Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||||
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
# Original source: https://github.com/karpathy/nanoGPT
|
||||||
Original source: https://github.com/karpathy/nanoGPT
|
#
|
||||||
|
# - The nanoGPT code is licensed under the MIT License:
|
||||||
We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
#
|
||||||
"""
|
# MIT License
|
||||||
|
#
|
||||||
"""
|
# Copyright (c) 2022 Andrej Karpathy
|
||||||
This is a part for nanoGPT that utilizes code from the following repository:
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
- Andrej Karpathy's nanoGPT implementation in PyTorch.
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
Original source: https://github.com/karpathy/nanoGPT
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
- The nanoGPT code is licensed under the MIT License:
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
MIT License
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
Copyright (c) 2022 Andrej Karpathy
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
in the Software without restriction, including without limitation the rights
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
furnished to do so, subject to the following conditions:
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
The above copyright notice and this permission notice shall be included in all
|
#
|
||||||
copies or substantial portions of the Software.
|
# - We've made some changes to the original code to adapt it to our needs.
|
||||||
|
#
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# Changed variable names:
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# - n_head -> gpt_n_head
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
# - n_embd -> gpt_hidden_dim
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# - block_size -> gpt_block_size
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
# - n_layer -> gpt_n_layer
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
#
|
||||||
SOFTWARE.
|
#
|
||||||
|
# class GPT(nn.Module):
|
||||||
- We've made some changes to the original code to adapt it to our needs.
|
# - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
||||||
|
# - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
||||||
Changed variable names:
|
# - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
||||||
- n_head -> gpt_n_head
|
|
||||||
- n_embd -> gpt_hidden_dim
|
|
||||||
- block_size -> gpt_block_size
|
|
||||||
- n_layer -> gpt_n_layer
|
|
||||||
|
|
||||||
|
|
||||||
class GPT(nn.Module):
|
|
||||||
- removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
|
||||||
- changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
|
||||||
- in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class CausalSelfAttention(nn.Module):
|
class CausalSelfAttention(nn.Module):
|
||||||
@@ -200,9 +195,9 @@ class GPT(nn.Module):
|
|||||||
n_params = sum(p.numel() for p in self.parameters())
|
n_params = sum(p.numel() for p in self.parameters())
|
||||||
print("number of parameters: {:.2f}M".format(n_params / 1e6))
|
print("number of parameters: {:.2f}M".format(n_params / 1e6))
|
||||||
|
|
||||||
def forward(self, input, targets=None):
|
def forward(self, forward_input):
|
||||||
device = input.device
|
device = forward_input.device
|
||||||
b, t, d = input.size()
|
_, t, _ = forward_input.size()
|
||||||
assert t <= self.config.gpt_block_size, (
|
assert t <= self.config.gpt_block_size, (
|
||||||
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||||
)
|
)
|
||||||
@@ -211,7 +206,7 @@ class GPT(nn.Module):
|
|||||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||||
|
|
||||||
# forward the GPT model itself
|
# forward the GPT model itself
|
||||||
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
tok_emb = self.transformer.wte(forward_input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||||
x = self.transformer.drop(tok_emb + pos_emb)
|
x = self.transformer.drop(tok_emb + pos_emb)
|
||||||
for block in self.transformer.h:
|
for block in self.transformer.h:
|
||||||
@@ -285,51 +280,48 @@ class GPT(nn.Module):
|
|||||||
return decay, no_decay
|
return decay, no_decay
|
||||||
|
|
||||||
|
|
||||||
"""
|
# This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||||
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
#
|
||||||
|
# - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||||
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
#
|
||||||
|
# - The vector-quantize-pytorch code is licensed under the MIT License:
|
||||||
- The vector-quantize-pytorch code is licensed under the MIT License:
|
#
|
||||||
|
# MIT License
|
||||||
MIT License
|
#
|
||||||
|
# Copyright (c) 2020 Phil Wang
|
||||||
Copyright (c) 2020 Phil Wang
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
# in the Software without restriction, including without limitation the rights
|
||||||
in the Software without restriction, including without limitation the rights
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
# furnished to do so, subject to the following conditions:
|
||||||
furnished to do so, subject to the following conditions:
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
The above copyright notice and this permission notice shall be included in all
|
# copies or substantial portions of the Software.
|
||||||
copies or substantial portions of the Software.
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# SOFTWARE.
|
||||||
SOFTWARE.
|
#
|
||||||
|
# - We've made some changes to the original code to adapt it to our needs.
|
||||||
- We've made some changes to the original code to adapt it to our needs.
|
#
|
||||||
|
# class ResidualVQ(nn.Module):
|
||||||
class ResidualVQ(nn.Module):
|
# - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
||||||
- added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
# This enables the user to save an indicator whether the codebook is frozen or not.
|
||||||
This enables the user to save an indicator whether the codebook is frozen or not.
|
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
# This is to make the function name more descriptive.
|
||||||
This is to make the function name more descriptive.
|
#
|
||||||
|
# class VectorQuantize(nn.Module):
|
||||||
class VectorQuantize(nn.Module):
|
# - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
||||||
- removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
# These parameters are not used in the code.
|
||||||
These parameters are not used in the code.
|
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
# This is to make the function name more descriptive.
|
||||||
This is to make the function name more descriptive.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualVQ(nn.Module):
|
class ResidualVQ(nn.Module):
|
||||||
@@ -479,6 +471,9 @@ class ResidualVQ(nn.Module):
|
|||||||
|
|
||||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||||
|
|
||||||
|
null_indices = None
|
||||||
|
null_loss = None
|
||||||
|
|
||||||
# sample a layer index at which to dropout further residual quantization
|
# sample a layer index at which to dropout further residual quantization
|
||||||
# also prepare null indices and loss
|
# also prepare null indices and loss
|
||||||
|
|
||||||
@@ -933,7 +928,7 @@ class VectorQuantize(nn.Module):
|
|||||||
return quantize, embed_ind, loss
|
return quantize, embed_ind, loss
|
||||||
|
|
||||||
|
|
||||||
def noop(*args, **kwargs):
|
def noop(*_args, **_kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file contains utilities for recording frames from Intel Realsense cameras.
|
This file contains utilities for recording frames from Intel Realsense cameras.
|
||||||
"""
|
"""
|
||||||
@@ -63,9 +77,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
|
|||||||
path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
|
path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
img.save(str(path), quality=100)
|
img.save(str(path), quality=100)
|
||||||
logging.info(f"Saved image: {path}")
|
logging.info("Saved image: %s", path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
logging.error("Failed to save image for camera %s frame %s: %s", serial_number, frame_index, e)
|
||||||
|
|
||||||
|
|
||||||
def save_images_from_cameras(
|
def save_images_from_cameras(
|
||||||
@@ -433,7 +447,7 @@ class IntelRealSenseCamera:
|
|||||||
num_tries += 1
|
num_tries += 1
|
||||||
time.sleep(1 / self.fps)
|
time.sleep(1 / self.fps)
|
||||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||||
raise Exception(
|
raise TimeoutError(
|
||||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
||||||
"""
|
"""
|
||||||
@@ -31,7 +45,7 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
|||||||
MAX_OPENCV_INDEX = 60
|
MAX_OPENCV_INDEX = 60
|
||||||
|
|
||||||
|
|
||||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
def find_cameras(max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||||
cameras = []
|
cameras = []
|
||||||
if platform.system() == "Linux":
|
if platform.system() == "Linux":
|
||||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||||
@@ -271,10 +285,20 @@ class OpenCVCamera:
|
|||||||
# when other threads are used to save the images.
|
# when other threads are used to save the images.
|
||||||
cv2.setNumThreads(1)
|
cv2.setNumThreads(1)
|
||||||
|
|
||||||
|
backend = (
|
||||||
|
cv2.CAP_V4L2
|
||||||
|
if platform.system() == "Linux"
|
||||||
|
else cv2.CAP_DSHOW
|
||||||
|
if platform.system() == "Windows"
|
||||||
|
else cv2.CAP_AVFOUNDATION
|
||||||
|
if platform.system() == "Darwin"
|
||||||
|
else cv2.CAP_ANY
|
||||||
|
)
|
||||||
|
|
||||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||||
# First create a temporary camera trying to access `camera_index`,
|
# First create a temporary camera trying to access `camera_index`,
|
||||||
# and verify it is a valid camera by calling `isOpened`.
|
# and verify it is a valid camera by calling `isOpened`.
|
||||||
tmp_camera = cv2.VideoCapture(camera_idx)
|
tmp_camera = cv2.VideoCapture(camera_idx, backend)
|
||||||
is_camera_open = tmp_camera.isOpened()
|
is_camera_open = tmp_camera.isOpened()
|
||||||
# Release camera to make it accessible for `find_camera_indices`
|
# Release camera to make it accessible for `find_camera_indices`
|
||||||
tmp_camera.release()
|
tmp_camera.release()
|
||||||
@@ -297,7 +321,7 @@ class OpenCVCamera:
|
|||||||
# Secondly, create the camera that will be used downstream.
|
# Secondly, create the camera that will be used downstream.
|
||||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||||
# needs to be re-created.
|
# needs to be re-created.
|
||||||
self.camera = cv2.VideoCapture(camera_idx)
|
self.camera = cv2.VideoCapture(camera_idx, backend)
|
||||||
|
|
||||||
if self.fps is not None:
|
if self.fps is not None:
|
||||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -1,14 +1,25 @@
|
|||||||
import logging
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -43,11 +54,6 @@ class RecordControlConfig(ControlConfig):
|
|||||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||||
root: str | Path | None = None
|
root: str | Path | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
|
||||||
device: str | None = None # cuda | cpu | mps
|
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
||||||
# automatic gradient scaling is used.
|
|
||||||
use_amp: bool | None = None
|
|
||||||
# Limit the frames per second. By default, uses the policy fps.
|
# Limit the frames per second. By default, uses the policy fps.
|
||||||
fps: int | None = None
|
fps: int | None = None
|
||||||
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
|
||||||
@@ -90,27 +96,6 @@ class RecordControlConfig(ControlConfig):
|
|||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
# When no device or use_amp are given, use the one from training config.
|
|
||||||
if self.device is None or self.use_amp is None:
|
|
||||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
|
||||||
if self.device is None:
|
|
||||||
self.device = train_cfg.device
|
|
||||||
if self.use_amp is None:
|
|
||||||
self.use_amp = train_cfg.use_amp
|
|
||||||
|
|
||||||
# Automatically switch to available device if necessary
|
|
||||||
if not is_torch_device_available(self.device):
|
|
||||||
auto_device = auto_select_torch_device()
|
|
||||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
|
||||||
self.device = auto_device
|
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
|
||||||
logging.warning(
|
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
||||||
)
|
|
||||||
self.use_amp = False
|
|
||||||
|
|
||||||
|
|
||||||
@ControlConfig.register_subclass("replay")
|
@ControlConfig.register_subclass("replay")
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
@@ -18,6 +32,7 @@ from termcolor import colored
|
|||||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import get_features_from_robot
|
from lerobot.common.datasets.utils import get_features_from_robot
|
||||||
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from lerobot.common.robot_devices.utils import busy_wait
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||||
@@ -179,8 +194,6 @@ def record_episode(
|
|||||||
episode_time_s,
|
episode_time_s,
|
||||||
display_cameras,
|
display_cameras,
|
||||||
policy,
|
policy,
|
||||||
device,
|
|
||||||
use_amp,
|
|
||||||
fps,
|
fps,
|
||||||
single_task,
|
single_task,
|
||||||
):
|
):
|
||||||
@@ -191,8 +204,6 @@ def record_episode(
|
|||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
events=events,
|
events=events,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=device,
|
|
||||||
use_amp=use_amp,
|
|
||||||
fps=fps,
|
fps=fps,
|
||||||
teleoperate=policy is None,
|
teleoperate=policy is None,
|
||||||
single_task=single_task,
|
single_task=single_task,
|
||||||
@@ -207,9 +218,7 @@ def control_loop(
|
|||||||
display_cameras=False,
|
display_cameras=False,
|
||||||
dataset: LeRobotDataset | None = None,
|
dataset: LeRobotDataset | None = None,
|
||||||
events=None,
|
events=None,
|
||||||
policy=None,
|
policy: PreTrainedPolicy = None,
|
||||||
device: torch.device | str | None = None,
|
|
||||||
use_amp: bool | None = None,
|
|
||||||
fps: int | None = None,
|
fps: int | None = None,
|
||||||
single_task: str | None = None,
|
single_task: str | None = None,
|
||||||
):
|
):
|
||||||
@@ -232,9 +241,6 @@ def control_loop(
|
|||||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||||
|
|
||||||
if isinstance(device, str):
|
|
||||||
device = get_safe_torch_device(device)
|
|
||||||
|
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
while timestamp < control_time_s:
|
while timestamp < control_time_s:
|
||||||
@@ -246,7 +252,9 @@ def control_loop(
|
|||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
||||||
if policy is not None:
|
if policy is not None:
|
||||||
pred_action = predict_action(observation, policy, device, use_amp)
|
pred_action = predict_action(
|
||||||
|
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||||
|
)
|
||||||
# Action can eventually be clipped using `max_relative_target`,
|
# Action can eventually be clipped using `max_relative_target`,
|
||||||
# so action actually sent is saved in the dataset.
|
# so action actually sent is saved in the dataset.
|
||||||
action = robot.send_action(pred_action)
|
action = robot.send_action(pred_action)
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -155,7 +169,8 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
|
|||||||
return steps
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes, mock=False):
|
# TODO(Steven): Similar function in feetch.py, should be moved to a common place.
|
||||||
|
def convert_to_bytes(value, byte, mock=False):
|
||||||
if mock:
|
if mock:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@@ -163,16 +178,16 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
|
|
||||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||||
# already handles it for us.
|
# already handles it for us.
|
||||||
if bytes == 1:
|
if byte == 1:
|
||||||
data = [
|
data = [
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 2:
|
elif byte == 2:
|
||||||
data = [
|
data = [
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 4:
|
elif byte == 4:
|
||||||
data = [
|
data = [
|
||||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||||
@@ -182,7 +197,7 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||||
f"{bytes} is provided instead."
|
f"{byte} is provided instead."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -214,9 +229,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
|
|||||||
all_addr = []
|
all_addr = []
|
||||||
all_bytes = []
|
all_bytes = []
|
||||||
for model in motor_models:
|
for model in motor_models:
|
||||||
addr, bytes = model_ctrl_table[model][data_name]
|
addr, byte = model_ctrl_table[model][data_name]
|
||||||
all_addr.append(addr)
|
all_addr.append(addr)
|
||||||
all_bytes.append(bytes)
|
all_bytes.append(byte)
|
||||||
|
|
||||||
if len(set(all_addr)) != 1:
|
if len(set(all_addr)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -562,6 +577,8 @@ class DynamixelMotorsBus:
|
|||||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||||
low_factor = (start_pos - values[i]) / resolution
|
low_factor = (start_pos - values[i]) / resolution
|
||||||
upp_factor = (end_pos - values[i]) / resolution
|
upp_factor = (end_pos - values[i]) / resolution
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
|
||||||
|
|
||||||
if not in_range:
|
if not in_range:
|
||||||
# Get first integer between the two bounds
|
# Get first integer between the two bounds
|
||||||
@@ -582,10 +599,15 @@ class DynamixelMotorsBus:
|
|||||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
|
||||||
|
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
"Auto-correct calibration of motor '%s' by shifting value by {abs(factor)} full turns, "
|
||||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
"from '%s' to '%s'.",
|
||||||
|
name,
|
||||||
|
out_of_range_str,
|
||||||
|
in_range_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||||
@@ -642,8 +664,8 @@ class DynamixelMotorsBus:
|
|||||||
motor_ids = [motor_ids]
|
motor_ids = [motor_ids]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
group.addParam(idx)
|
group.addParam(idx)
|
||||||
|
|
||||||
@@ -660,7 +682,7 @@ class DynamixelMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = group.getData(idx, addr, bytes)
|
value = group.getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
if return_list:
|
if return_list:
|
||||||
@@ -695,13 +717,13 @@ class DynamixelMotorsBus:
|
|||||||
models.append(model)
|
models.append(model)
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
if data_name not in self.group_readers:
|
if data_name not in self.group_readers:
|
||||||
# create new group reader
|
# create new group reader
|
||||||
self.group_readers[group_key] = dxl.GroupSyncRead(
|
self.group_readers[group_key] = dxl.GroupSyncRead(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
self.group_readers[group_key].addParam(idx)
|
self.group_readers[group_key].addParam(idx)
|
||||||
@@ -719,7 +741,7 @@ class DynamixelMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
value = self.group_readers[group_key].getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
values = np.array(values)
|
values = np.array(values)
|
||||||
@@ -753,10 +775,10 @@ class DynamixelMotorsBus:
|
|||||||
values = [values]
|
values = [values]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
group.addParam(idx, data)
|
group.addParam(idx, data)
|
||||||
|
|
||||||
for _ in range(num_retry):
|
for _ in range(num_retry):
|
||||||
@@ -807,17 +829,17 @@ class DynamixelMotorsBus:
|
|||||||
values = values.tolist()
|
values = values.tolist()
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
init_group = data_name not in self.group_readers
|
init_group = data_name not in self.group_readers
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key] = dxl.GroupSyncWrite(
|
self.group_writers[group_key] = dxl.GroupSyncWrite(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key].addParam(idx, data)
|
self.group_writers[group_key].addParam(idx, data)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -134,7 +148,7 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
|
|||||||
return steps
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bytes(value, bytes, mock=False):
|
def convert_to_bytes(value, byte, mock=False):
|
||||||
if mock:
|
if mock:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@@ -142,16 +156,16 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
|
|
||||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||||
# already handles it for us.
|
# already handles it for us.
|
||||||
if bytes == 1:
|
if byte == 1:
|
||||||
data = [
|
data = [
|
||||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 2:
|
elif byte == 2:
|
||||||
data = [
|
data = [
|
||||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||||
]
|
]
|
||||||
elif bytes == 4:
|
elif byte == 4:
|
||||||
data = [
|
data = [
|
||||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||||
@@ -161,7 +175,7 @@ def convert_to_bytes(value, bytes, mock=False):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||||
f"{bytes} is provided instead."
|
f"{byte} is provided instead."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -193,9 +207,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
|
|||||||
all_addr = []
|
all_addr = []
|
||||||
all_bytes = []
|
all_bytes = []
|
||||||
for model in motor_models:
|
for model in motor_models:
|
||||||
addr, bytes = model_ctrl_table[model][data_name]
|
addr, byte = model_ctrl_table[model][data_name]
|
||||||
all_addr.append(addr)
|
all_addr.append(addr)
|
||||||
all_bytes.append(bytes)
|
all_bytes.append(byte)
|
||||||
|
|
||||||
if len(set(all_addr)) != 1:
|
if len(set(all_addr)) != 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -543,6 +557,8 @@ class FeetechMotorsBus:
|
|||||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||||
low_factor = (start_pos - values[i]) / resolution
|
low_factor = (start_pos - values[i]) / resolution
|
||||||
upp_factor = (end_pos - values[i]) / resolution
|
upp_factor = (end_pos - values[i]) / resolution
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||||
|
|
||||||
if not in_range:
|
if not in_range:
|
||||||
# Get first integer between the two bounds
|
# Get first integer between the two bounds
|
||||||
@@ -563,10 +579,16 @@ class FeetechMotorsBus:
|
|||||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||||
|
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
"Auto-correct calibration of motor '%s' by shifting value by %s full turns, "
|
||||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
"from '%s' to '%s'.",
|
||||||
|
name,
|
||||||
|
abs(factor),
|
||||||
|
out_of_range_str,
|
||||||
|
in_range_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||||
@@ -660,8 +682,8 @@ class FeetechMotorsBus:
|
|||||||
motor_ids = [motor_ids]
|
motor_ids = [motor_ids]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
group.addParam(idx)
|
group.addParam(idx)
|
||||||
|
|
||||||
@@ -678,7 +700,7 @@ class FeetechMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = group.getData(idx, addr, bytes)
|
value = group.getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
if return_list:
|
if return_list:
|
||||||
@@ -713,7 +735,7 @@ class FeetechMotorsBus:
|
|||||||
models.append(model)
|
models.append(model)
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
if data_name not in self.group_readers:
|
if data_name not in self.group_readers:
|
||||||
@@ -723,7 +745,7 @@ class FeetechMotorsBus:
|
|||||||
|
|
||||||
# create new group reader
|
# create new group reader
|
||||||
self.group_readers[group_key] = scs.GroupSyncRead(
|
self.group_readers[group_key] = scs.GroupSyncRead(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
self.group_readers[group_key].addParam(idx)
|
self.group_readers[group_key].addParam(idx)
|
||||||
@@ -741,7 +763,7 @@ class FeetechMotorsBus:
|
|||||||
|
|
||||||
values = []
|
values = []
|
||||||
for idx in motor_ids:
|
for idx in motor_ids:
|
||||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
value = self.group_readers[group_key].getData(idx, addr, byte)
|
||||||
values.append(value)
|
values.append(value)
|
||||||
|
|
||||||
values = np.array(values)
|
values = np.array(values)
|
||||||
@@ -778,10 +800,10 @@ class FeetechMotorsBus:
|
|||||||
values = [values]
|
values = [values]
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||||
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
group.addParam(idx, data)
|
group.addParam(idx, data)
|
||||||
|
|
||||||
for _ in range(num_retry):
|
for _ in range(num_retry):
|
||||||
@@ -832,17 +854,17 @@ class FeetechMotorsBus:
|
|||||||
values = values.tolist()
|
values = values.tolist()
|
||||||
|
|
||||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
addr, byte = self.model_ctrl_table[model][data_name]
|
||||||
group_key = get_group_sync_key(data_name, motor_names)
|
group_key = get_group_sync_key(data_name, motor_names)
|
||||||
|
|
||||||
init_group = data_name not in self.group_readers
|
init_group = data_name not in self.group_readers
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key] = scs.GroupSyncWrite(
|
self.group_writers[group_key] = scs.GroupSyncWrite(
|
||||||
self.port_handler, self.packet_handler, addr, bytes
|
self.port_handler, self.packet_handler, addr, byte
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, value in zip(motor_ids, values, strict=True):
|
for idx, value in zip(motor_ids, values, strict=True):
|
||||||
data = convert_to_bytes(value, bytes, self.mock)
|
data = convert_to_bytes(value, byte, self.mock)
|
||||||
if init_group:
|
if init_group:
|
||||||
self.group_writers[group_key].addParam(idx, data)
|
self.group_writers[group_key].addParam(idx, data)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from lerobot.common.robot_devices.motors.configs import (
|
from lerobot.common.robot_devices.motors.configs import (
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
"""Logic to calibrate a robot arm built with dynamixel motors"""
|
||||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Logic to calibrate a robot arm built with feetech motors"""
|
"""Logic to calibrate a robot arm built with feetech motors"""
|
||||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||||
|
|
||||||
@@ -81,6 +95,8 @@ def move_to_calibrate(
|
|||||||
while_move_hook=None,
|
while_move_hook=None,
|
||||||
):
|
):
|
||||||
initial_pos = arm.read("Present_Position", motor_name)
|
initial_pos = arm.read("Present_Position", motor_name)
|
||||||
|
p_present_pos = None
|
||||||
|
n_present_pos = None
|
||||||
|
|
||||||
if positive_first:
|
if positive_first:
|
||||||
p_present_pos = move_until_block(
|
p_present_pos = move_until_block(
|
||||||
@@ -182,7 +198,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
|
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
|
||||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
|
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
|
||||||
|
|
||||||
def in_between_move_hook():
|
def in_between_move_hook_elbow():
|
||||||
nonlocal arm, calib
|
nonlocal arm, calib
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
ef_pos = arm.read("Present_Position", "elbow_flex")
|
ef_pos = arm.read("Present_Position", "elbow_flex")
|
||||||
@@ -193,14 +209,14 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
|
|
||||||
print("Calibrate elbow_flex")
|
print("Calibrate elbow_flex")
|
||||||
calib["elbow_flex"] = move_to_calibrate(
|
calib["elbow_flex"] = move_to_calibrate(
|
||||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook_elbow
|
||||||
)
|
)
|
||||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||||
|
|
||||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def in_between_move_hook():
|
def in_between_move_hook_shoulder():
|
||||||
nonlocal arm, calib
|
nonlocal arm, calib
|
||||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
|
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
|
||||||
|
|
||||||
@@ -210,7 +226,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
|||||||
"shoulder_lift",
|
"shoulder_lift",
|
||||||
invert_drive_mode=True,
|
invert_drive_mode=True,
|
||||||
positive_first=False,
|
positive_first=False,
|
||||||
in_between_move_hook=in_between_move_hook,
|
in_between_move_hook=in_between_move_hook_shoulder,
|
||||||
)
|
)
|
||||||
# add an 30 steps as offset to align with body
|
# add an 30 steps as offset to align with body
|
||||||
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
|
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
@@ -53,14 +67,14 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if calib_file.exists():
|
if calib_file.exists():
|
||||||
with open(calib_file) as f:
|
with open(calib_file, encoding="utf-8") as f:
|
||||||
calibration = json.load(f)
|
calibration = json.load(f)
|
||||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||||
else:
|
else:
|
||||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||||
with open(calib_file, "w") as f:
|
with open(calib_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(calibration, f)
|
json.dump(calibration, f)
|
||||||
try:
|
try:
|
||||||
motors_bus.set_calibration(calibration)
|
motors_bus.set_calibration(calibration)
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains logic to instantiate a robot, read information from its motors and cameras,
|
"""Contains logic to instantiate a robot, read information from its motors and cameras,
|
||||||
and send orders to its motors.
|
and send orders to its motors.
|
||||||
"""
|
"""
|
||||||
@@ -33,8 +47,10 @@ def ensure_safe_goal_position(
|
|||||||
if not torch.allclose(goal_pos, safe_goal_pos):
|
if not torch.allclose(goal_pos, safe_goal_pos):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||||
f" requested relative goal position target: {diff}\n"
|
" requested relative goal position target: %s\n"
|
||||||
f" clamped relative goal position target: {safe_diff}"
|
" clamped relative goal position target: %s",
|
||||||
|
diff,
|
||||||
|
safe_diff,
|
||||||
)
|
)
|
||||||
|
|
||||||
return safe_goal_pos
|
return safe_goal_pos
|
||||||
@@ -231,6 +247,8 @@ class ManipulatorRobot:
|
|||||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Robot type {self.robot_type} is not supported")
|
||||||
|
|
||||||
# We assume that at connection time, arms are in a rest position, and torque can
|
# We assume that at connection time, arms are in a rest position, and torque can
|
||||||
# be safely disabled to run calibration and/or set robot preset configurations.
|
# be safely disabled to run calibration and/or set robot preset configurations.
|
||||||
@@ -288,7 +306,7 @@ class ManipulatorRobot:
|
|||||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||||
|
|
||||||
if arm_calib_path.exists():
|
if arm_calib_path.exists():
|
||||||
with open(arm_calib_path) as f:
|
with open(arm_calib_path, encoding="utf-8") as f:
|
||||||
calibration = json.load(f)
|
calibration = json.load(f)
|
||||||
else:
|
else:
|
||||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||||
@@ -308,7 +326,7 @@ class ManipulatorRobot:
|
|||||||
|
|
||||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(arm_calib_path, "w") as f:
|
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(calibration, f)
|
json.dump(calibration, f)
|
||||||
|
|
||||||
return calibration
|
return calibration
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -248,14 +262,14 @@ class MobileManipulator:
|
|||||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||||
|
|
||||||
if arm_calib_path.exists():
|
if arm_calib_path.exists():
|
||||||
with open(arm_calib_path) as f:
|
with open(arm_calib_path, encoding="utf-8") as f:
|
||||||
calibration = json.load(f)
|
calibration = json.load(f)
|
||||||
else:
|
else:
|
||||||
print(f"Missing calibration file '{arm_calib_path}'")
|
print(f"Missing calibration file '{arm_calib_path}'")
|
||||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(arm_calib_path, "w") as f:
|
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(calibration, f)
|
json.dump(calibration, f)
|
||||||
|
|
||||||
return calibration
|
return calibration
|
||||||
@@ -358,6 +372,7 @@ class MobileManipulator:
|
|||||||
|
|
||||||
present_speed = self.last_present_speed
|
present_speed = self.last_present_speed
|
||||||
|
|
||||||
|
# TODO(Steven): [WARN] Plenty of general exceptions
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[DEBUG] Error decoding video message: {e}")
|
print(f"[DEBUG] Error decoding video message: {e}")
|
||||||
# If decode fails, fall back to old data
|
# If decode fails, fall back to old data
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.configs import (
|
from lerobot.common.robot_devices.robots.configs import (
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|||||||
@@ -68,9 +68,9 @@ class TimeBenchmark(ContextDecorator):
|
|||||||
Block took approximately 10.00 milliseconds
|
Block took approximately 10.00 milliseconds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, print=False):
|
def __init__(self, print_time=False):
|
||||||
self.local = threading.local()
|
self.local = threading.local()
|
||||||
self.print_time = print
|
self.print_time = print_time
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.local.start_time = time.perf_counter()
|
self.local.start_time = time.perf_counter()
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Any, Type, TypeVar
|
from typing import Any, Type, TypeVar
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
|||||||
else:
|
else:
|
||||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||||
package_exists = False
|
package_exists = False
|
||||||
logging.debug(f"Detected {pkg_name} version: {package_version}")
|
logging.debug("Detected %s version: %s", {pkg_name}, package_version)
|
||||||
if return_version:
|
if return_version:
|
||||||
return package_exists, package_version
|
return package_exists, package_version
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ class AverageMeter:
|
|||||||
def __init__(self, name: str, fmt: str = ":f"):
|
def __init__(self, name: str, fmt: str = ":f"):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.fmt = fmt
|
self.fmt = fmt
|
||||||
|
self.val = 0.0
|
||||||
|
self.avg = 0.0
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
|
|||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||||
|
try_device = str(try_device)
|
||||||
match try_device:
|
match try_device:
|
||||||
case "cuda":
|
case "cuda":
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
@@ -67,7 +69,7 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
|||||||
case _:
|
case _:
|
||||||
device = torch.device(try_device)
|
device = torch.device(try_device)
|
||||||
if log:
|
if log:
|
||||||
logging.warning(f"Using custom {try_device} device.")
|
logging.warning("Using custom %s device.", try_device)
|
||||||
|
|
||||||
return device
|
return device
|
||||||
|
|
||||||
@@ -85,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
|||||||
|
|
||||||
|
|
||||||
def is_torch_device_available(try_device: str) -> bool:
|
def is_torch_device_available(try_device: str) -> bool:
|
||||||
|
try_device = str(try_device) # Ensure try_device is a string
|
||||||
if try_device == "cuda":
|
if try_device == "cuda":
|
||||||
return torch.cuda.is_available()
|
return torch.cuda.is_available()
|
||||||
elif try_device == "mps":
|
elif try_device == "mps":
|
||||||
@@ -92,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
|
|||||||
elif try_device == "cpu":
|
elif try_device == "cpu":
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown device '{try_device}.")
|
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||||
|
|
||||||
|
|
||||||
def is_amp_available(device: str):
|
def is_amp_available(device: str):
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class WandBLogger:
|
|||||||
resume="must" if cfg.resume else None,
|
resume="must" if cfg.resume else None,
|
||||||
)
|
)
|
||||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
logging.info("Track this run --> %s", colored(wandb.run.get_url(), "yellow", attrs=["bold"]))
|
||||||
self._wandb = wandb
|
self._wandb = wandb
|
||||||
|
|
||||||
def log_policy(self, checkpoint_dir: Path):
|
def log_policy(self, checkpoint_dir: Path):
|
||||||
@@ -108,7 +108,7 @@ class WandBLogger:
|
|||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if not isinstance(v, (int, float, str)):
|
if not isinstance(v, (int, float, str)):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
'WandB logging of key "%s" was ignored as its type is not handled by this wrapper.', k
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||||
|
|||||||
@@ -1,14 +1,26 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lerobot.common import envs, policies # noqa: F401
|
from lerobot.common import envs, policies # noqa: F401
|
||||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import EvalConfig
|
from lerobot.configs.default import EvalConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -21,11 +33,6 @@ class EvalPipelineConfig:
|
|||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
output_dir: Path | None = None
|
output_dir: Path | None = None
|
||||||
job_name: str | None = None
|
job_name: str | None = None
|
||||||
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
|
||||||
device: str | None = None # cuda | cpu | mps
|
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
||||||
# automatic gradient scaling is used.
|
|
||||||
use_amp: bool = False
|
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -36,27 +43,6 @@ class EvalPipelineConfig:
|
|||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
# When no device or use_amp are given, use the one from training config.
|
|
||||||
if self.device is None or self.use_amp is None:
|
|
||||||
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
|
||||||
if self.device is None:
|
|
||||||
self.device = train_cfg.device
|
|
||||||
if self.use_amp is None:
|
|
||||||
self.use_amp = train_cfg.use_amp
|
|
||||||
|
|
||||||
# Automatically switch to available device if necessary
|
|
||||||
if not is_torch_device_available(self.device):
|
|
||||||
auto_device = auto_select_torch_device()
|
|
||||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
|
||||||
self.device = auto_device
|
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
|
||||||
logging.warning(
|
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
||||||
)
|
|
||||||
self.use_amp = False
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||||
@@ -73,11 +59,6 @@ class EvalPipelineConfig:
|
|||||||
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||||
self.output_dir = Path("outputs/eval") / eval_dir
|
self.output_dir = Path("outputs/eval") / eval_dir
|
||||||
|
|
||||||
if self.device is None:
|
|
||||||
raise ValueError("Set one of the following device: cuda, cpu or mps")
|
|
||||||
elif self.device == "cuda" and self.use_amp is None:
|
|
||||||
raise ValueError("Set 'use_amp' to True or False.")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentError
|
from argparse import ArgumentError
|
||||||
|
|||||||
@@ -1,4 +1,18 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -12,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
|||||||
from lerobot.common.optim.optimizers import OptimizerConfig
|
from lerobot.common.optim.optimizers import OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
|
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
|
||||||
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
||||||
@@ -40,22 +55,42 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
device: str | None = None # cuda | cpu | mp
|
||||||
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||||
|
# automatic gradient scaling is used.
|
||||||
|
use_amp: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.pretrained_path = None
|
self.pretrained_path = None
|
||||||
|
if not self.device or not is_torch_device_available(self.device):
|
||||||
|
auto_device = auto_select_torch_device()
|
||||||
|
logging.warning("Device '%s' is not available. Switching to '%s'.", self.device, auto_device)
|
||||||
|
self.device = auto_device.type
|
||||||
|
|
||||||
|
# Automatically deactivate AMP if necessary
|
||||||
|
if self.use_amp and not is_amp_available(self.device):
|
||||||
|
logging.warning(
|
||||||
|
"Automatic Mixed Precision (amp) is not available on device '%s'. Deactivating AMP.",
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
self.use_amp = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return self.get_choice_name(self.__class__)
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def observation_delta_indices(self) -> list | None:
|
def observation_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def action_delta_indices(self) -> list | None:
|
def action_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractproperty
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def reward_delta_indices(self) -> list | None:
|
def reward_delta_indices(self) -> list | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -97,7 +132,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path) -> None:
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
with open(save_directory / CONFIG_NAME, "w", encoding="utf-8") as f, draccus.config_type("json"):
|
||||||
draccus.dump(self, f, indent=4)
|
draccus.dump(self, f, indent=4)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,5 +1,17 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -13,7 +25,6 @@ from lerobot.common import envs
|
|||||||
from lerobot.common.optim import OptimizerConfig
|
from lerobot.common.optim import OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
@@ -35,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
||||||
# regardless of what's provided with the training command at the time of resumption.
|
# regardless of what's provided with the training command at the time of resumption.
|
||||||
resume: bool = False
|
resume: bool = False
|
||||||
device: str | None = None # cuda | cpu | mp
|
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
|
||||||
# automatic gradient scaling is used.
|
|
||||||
use_amp: bool = False
|
|
||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: int | None = 1000
|
seed: int | None = 1000
|
||||||
@@ -61,18 +68,6 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
self.checkpoint_path = None
|
self.checkpoint_path = None
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
if not self.device:
|
|
||||||
logging.warning("No device specified, trying to infer device automatically")
|
|
||||||
device = auto_select_torch_device()
|
|
||||||
self.device = device.type
|
|
||||||
|
|
||||||
# Automatically deactivate AMP if necessary
|
|
||||||
if self.use_amp and not is_amp_available(self.device):
|
|
||||||
logging.warning(
|
|
||||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
|
||||||
)
|
|
||||||
self.use_amp = False
|
|
||||||
|
|
||||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
if policy_path:
|
if policy_path:
|
||||||
@@ -128,7 +123,10 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
return draccus.encode(self)
|
return draccus.encode(self)
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path) -> None:
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
with (
|
||||||
|
open(save_directory / TRAIN_CONFIG_NAME, "w", encoding="utf-8") as f,
|
||||||
|
draccus.config_type("json"),
|
||||||
|
):
|
||||||
draccus.dump(self, f, indent=4)
|
draccus.dump(self, f, indent=4)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
# Note: We subclass str so that serialization is straightforward
|
# Note: We subclass str so that serialization is straightforward
|
||||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
This script configure a single motor at a time to a given ID and baudrate.
|
This script configure a single motor at a time to a given ID and baudrate.
|
||||||
|
|
||||||
@@ -77,6 +90,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
|||||||
print("Scanning all baudrates and motor indices")
|
print("Scanning all baudrates and motor indices")
|
||||||
all_baudrates = set(series_baudrate_table.values())
|
all_baudrates = set(series_baudrate_table.values())
|
||||||
motor_index = -1 # Set the motor index to an out-of-range value.
|
motor_index = -1 # Set the motor index to an out-of-range value.
|
||||||
|
baudrate = None
|
||||||
|
|
||||||
for baudrate in all_baudrates:
|
for baudrate in all_baudrates:
|
||||||
motor_bus.set_bus_baudrate(baudrate)
|
motor_bus.set_bus_baudrate(baudrate)
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Utilities to control a robot.
|
Utilities to control a robot.
|
||||||
|
|
||||||
@@ -254,7 +267,7 @@ def record(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
@@ -285,8 +298,6 @@ def record(
|
|||||||
episode_time_s=cfg.episode_time_s,
|
episode_time_s=cfg.episode_time_s,
|
||||||
display_cameras=cfg.display_cameras,
|
display_cameras=cfg.display_cameras,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=cfg.device,
|
|
||||||
use_amp=cfg.use_amp,
|
|
||||||
fps=cfg.fps,
|
fps=cfg.fps,
|
||||||
single_task=cfg.single_task,
|
single_task=cfg.single_task,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Utilities to control a robot in simulation.
|
Utilities to control a robot in simulation.
|
||||||
|
|
||||||
@@ -68,6 +81,7 @@ This might require a sudo permission to allow your terminal to monitor keyboard
|
|||||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO(Steven): This script should be updated to use the new robot API and the new dataset API.
|
||||||
import argparse
|
import argparse
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
|
|||||||
|
|
||||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
cuda_version = torch.version.cuda if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||||
|
|||||||
@@ -259,6 +259,10 @@ def eval_policy(
|
|||||||
threads = [] # for video saving threads
|
threads = [] # for video saving threads
|
||||||
n_episodes_rendered = 0 # for saving the correct number of videos
|
n_episodes_rendered = 0 # for saving the correct number of videos
|
||||||
|
|
||||||
|
video_paths: list[str] = [] # max_episodes_rendered > 0:
|
||||||
|
ep_frames: list[np.ndarray] = [] # max_episodes_rendered > 0
|
||||||
|
episode_data: dict | None = None # return_episode_data == True
|
||||||
|
|
||||||
# Callback for visualization.
|
# Callback for visualization.
|
||||||
def render_frame(env: gym.vector.VectorEnv):
|
def render_frame(env: gym.vector.VectorEnv):
|
||||||
# noqa: B023
|
# noqa: B023
|
||||||
@@ -271,19 +275,11 @@ def eval_policy(
|
|||||||
# Here we must render all frames and discard any we don't need.
|
# Here we must render all frames and discard any we don't need.
|
||||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||||
|
|
||||||
if max_episodes_rendered > 0:
|
|
||||||
video_paths: list[str] = []
|
|
||||||
|
|
||||||
if return_episode_data:
|
|
||||||
episode_data: dict | None = None
|
|
||||||
|
|
||||||
# we dont want progress bar when we use slurm, since it clutters the logs
|
# we dont want progress bar when we use slurm, since it clutters the logs
|
||||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||||
for batch_ix in progbar:
|
for batch_ix in progbar:
|
||||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||||
# step.
|
# step.
|
||||||
if max_episodes_rendered > 0:
|
|
||||||
ep_frames: list[np.ndarray] = []
|
|
||||||
|
|
||||||
if start_seed is None:
|
if start_seed is None:
|
||||||
seeds = None
|
seeds = None
|
||||||
@@ -320,13 +316,19 @@ def eval_policy(
|
|||||||
else:
|
else:
|
||||||
all_seeds.append(None)
|
all_seeds.append(None)
|
||||||
|
|
||||||
# FIXME: episode_data is either None or it doesn't exist
|
|
||||||
if return_episode_data:
|
if return_episode_data:
|
||||||
|
if episode_data is None:
|
||||||
|
start_data_index = 0
|
||||||
|
elif isinstance(episode_data, dict):
|
||||||
|
start_data_index = episode_data["index"][-1].item() + 1
|
||||||
|
else:
|
||||||
|
start_data_index = 0
|
||||||
|
|
||||||
this_episode_data = _compile_episode_data(
|
this_episode_data = _compile_episode_data(
|
||||||
rollout_data,
|
rollout_data,
|
||||||
done_indices,
|
done_indices,
|
||||||
start_episode_index=batch_ix * env.num_envs,
|
start_episode_index=batch_ix * env.num_envs,
|
||||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
start_data_index=start_data_index,
|
||||||
fps=env.unwrapped.metadata["render_fps"],
|
fps=env.unwrapped.metadata["render_fps"],
|
||||||
)
|
)
|
||||||
if episode_data is None:
|
if episode_data is None:
|
||||||
@@ -453,12 +455,13 @@ def _compile_episode_data(
|
|||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(Steven): [WARN] Redefining built-in 'eval'
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def eval_main(cfg: EvalPipelineConfig):
|
def eval_main(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
@@ -470,14 +473,14 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||||
|
|
||||||
logging.info("Making policy.")
|
logging.info("Making policy.")
|
||||||
|
|
||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
device=device,
|
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy(
|
info = eval_policy(
|
||||||
env,
|
env,
|
||||||
policy,
|
policy,
|
||||||
@@ -489,7 +492,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
print(info["aggregated"])
|
print(info["aggregated"])
|
||||||
|
|
||||||
# Save info
|
# Save info
|
||||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
with open(Path(cfg.output_dir) / "eval_info.json", "w", encoding="utf-8") as f:
|
||||||
json.dump(info, f, indent=2)
|
json.dump(info, f, indent=2)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|||||||
@@ -1,3 +1,16 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user