Compare commits

...

9 Commits

Author SHA1 Message Date
Steven Palma
7d420de118 [skip ci] chore(pre-commit): bump python version from 3.10 to python3.12 2025-03-03 18:02:18 +01:00
Steven Palma
bf6f89a5b5 fix(examples): Add Tensor type check (#799) 2025-03-03 17:01:04 +01:00
Simon Alibert
8861546ad8 [Security] Add Bandit (#795) 2025-03-01 19:19:26 +01:00
Simon Alibert
9c1a893ee3 [CI] Update Stylebot Permissions (#792) 2025-03-01 12:12:19 +01:00
Simon Alibert
e81c36cf74 Fix dataset version tags (#790) 2025-02-28 14:36:20 +01:00
Simon Alibert
ed83cbd4f2 Switch pyav -> av (#780) 2025-02-28 11:06:55 +01:00
Simon Alibert
2a33b9ad87 Revert "Fix pr_style_bot" (#787) 2025-02-27 16:49:18 +01:00
Quentin Gallouédec
6e85aa13ec Break style to test style bot (#785) 2025-02-27 16:46:06 +01:00
Simon Alibert
af05a1725c Fix pr_style_bot (#786) 2025-02-27 16:43:12 +01:00
11 changed files with 133 additions and 46 deletions

View File

@@ -5,17 +5,50 @@ on:
issue_comment: issue_comment:
types: [created] types: [created]
permissions: permissions: {}
contents: write
pull-requests: write env:
PYTHON_VERSION: "3.10"
jobs: jobs:
run-style-bot: check-permissions:
if: > if: >
contains(github.event.comment.body, '@bot /style') && contains(github.event.comment.body, '@bot /style') &&
github.event.issue.pull_request != null github.event.issue.pull_request != null
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs:
is_authorized: ${{ steps.check_user_permission.outputs.has_permission }}
steps:
- name: Check user permission
id: check_user_permission
uses: actions/github-script@v6
with:
script: |
const comment_user = context.payload.comment.user.login;
const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({
owner: context.repo.owner,
repo: context.repo.repo,
username: comment_user
});
const authorized =
permission.permission === 'admin' ||
permission.permission === 'write';
console.log(
`User ${comment_user} has permission level: ${permission.permission}, ` +
`authorized: ${authorized} (admins & maintainers allowed)`
);
core.setOutput('has_permission', authorized);
run-style-bot:
needs: check-permissions
if: needs.check-permissions.outputs.is_authorized == 'true'
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps: steps:
- name: Extract PR details - name: Extract PR details
id: pr_info id: pr_info
@@ -61,6 +94,8 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Get Ruff Version from pre-commit-config.yaml - name: Get Ruff Version from pre-commit-config.yaml
id: get-ruff-version id: get-ruff-version
@@ -91,6 +126,7 @@ jobs:
# Configure git with the Actions bot user # Configure git with the Actions bot user
git config user.name "github-actions[bot]" git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com" git config user.email "github-actions[bot]@users.noreply.github.com"
git config --local lfs.https://github.com/.locksverify false
# Make sure your 'origin' remote is set to the contributor's fork # Make sure your 'origin' remote is set to the contributor's fork
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git" git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"

View File

@@ -1,7 +1,8 @@
exclude: ^(tests/data) exclude: ^(tests/data)
default_language_version: default_language_version:
python: python3.10 python: python3.12
repos: repos:
##### Style / Misc. #####
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0 rev: v5.0.0
hooks: hooks:
@@ -14,7 +15,7 @@ repos:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/crate-ci/typos - repo: https://github.com/crate-ci/typos
rev: v1.29.10 rev: v1.30.0
hooks: hooks:
- id: typos - id: typos
args: [--force-exclude] args: [--force-exclude]
@@ -23,16 +24,24 @@ repos:
hooks: hooks:
- id: pyupgrade - id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.6 rev: v0.9.9
hooks: hooks:
- id: ruff - id: ruff
args: [--fix] args: [--fix]
- id: ruff-format - id: ruff-format
##### Security #####
- repo: https://github.com/gitleaks/gitleaks - repo: https://github.com/gitleaks/gitleaks
rev: v8.23.3 rev: v8.24.0
hooks: hooks:
- id: gitleaks - id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit - repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.3.1 rev: v1.4.1
hooks: hooks:
- id: zizmor - id: zizmor
- repo: https://github.com/PyCQA/bandit
rev: 1.8.3
hooks:
- id: bandit
args: ["-c", "pyproject.toml"]
additional_dependencies: ["bandit[toml]"]

View File

@@ -85,7 +85,7 @@ def main():
done = False done = False
while not done: while not done:
for batch in dataloader: for batch in dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
loss, _ = policy.forward(batch) loss, _ = policy.forward(batch)
loss.backward() loss.backward()
optimizer.step() optimizer.step()

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import logging import logging
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -27,6 +28,7 @@ import torch.utils
from datasets import concatenate_datasets, load_dataset from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -517,6 +519,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
branch: str | None = None, branch: str | None = None,
tags: list | None = None, tags: list | None = None,
license: str | None = "apache-2.0", license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True, push_videos: bool = True,
private: bool = False, private: bool = False,
allow_patterns: list[str] | str | None = None, allow_patterns: list[str] | str | None = None,
@@ -562,6 +565,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
) )
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def pull_from_repo( def pull_from_repo(
self, self,
allow_patterns: list[str] | str | None = None, allow_patterns: list[str] | str | None = None,

View File

@@ -31,6 +31,7 @@ import packaging.version
import torch import torch
from datasets.table import embed_table_storage from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage from PIL import Image as PILImage
from torchvision import transforms from torchvision import transforms
@@ -325,6 +326,19 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
) )
hub_versions = get_repo_versions(repo_id) hub_versions = get_repo_versions(repo_id)
if not hub_versions:
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
if target_version in hub_versions: if target_version in hub_versions:
return f"v{target_version}" return f"v{target_version}"

View File

@@ -57,7 +57,7 @@ def convert_dataset(
dataset.meta.info["codebase_version"] = CODEBASE_VERSION dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root) write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, allow_patterns="meta/") dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file # delete old stats.json file
if (dataset.root / STATS_PATH).is_file: if (dataset.root / STATS_PATH).is_file:

View File

@@ -17,6 +17,7 @@ import logging
import os import os
import os.path as osp import os.path as osp
import platform import platform
import subprocess
from copy import copy from copy import copy
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
@@ -165,23 +166,31 @@ def capture_timestamp_utc():
def say(text, blocking=False): def say(text, blocking=False):
# Check if mac, linux, or windows. system = platform.system()
if platform.system() == "Darwin":
cmd = f'say "{text}"'
if not blocking:
cmd += " &"
elif platform.system() == "Linux":
cmd = f'spd-say "{text}"'
if blocking:
cmd += " --wait"
elif platform.system() == "Windows":
# TODO(rcadene): Make blocking option work for Windows
cmd = (
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
)
os.system(cmd) if system == "Darwin":
cmd = ["say", text]
elif system == "Linux":
cmd = ["spd-say", text]
if blocking:
cmd.append("--wait")
elif system == "Windows":
cmd = [
"PowerShell",
"-Command",
"Add-Type -AssemblyName System.Speech; "
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
]
else:
raise RuntimeError("Unsupported operating system for text-to-speech.")
if blocking:
subprocess.run(cmd, check=True)
else:
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False): def log_say(text, play_sounds, blocking=False):

View File

@@ -454,7 +454,7 @@ def _compile_episode_data(
@parser.wrap() @parser.wrap()
def eval(cfg: EvalPipelineConfig): def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))
# Check device is available # Check device is available
@@ -499,4 +499,4 @@ def eval(cfg: EvalPipelineConfig):
if __name__ == "__main__": if __name__ == "__main__":
init_logging() init_logging()
eval() eval_main()

View File

@@ -194,7 +194,7 @@ def run_server(
] ]
response = requests.get( response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl" f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
) )
response.raise_for_status() response.raise_for_status()
# Split into lines and parse each line as JSON # Split into lines and parse each line as JSON
@@ -318,7 +318,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
def get_dataset_info(repo_id: str) -> IterableNamespace: def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json") response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
)
response.raise_for_status() # Raises an HTTPError for bad responses response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json() dataset_info = response.json()
dataset_info["repo_id"] = repo_id dataset_info["repo_id"] = repo_id

View File

@@ -48,7 +48,7 @@ dependencies = [
"omegaconf>=2.3.0", "omegaconf>=2.3.0",
"opencv-python>=4.9.0", "opencv-python>=4.9.0",
"packaging>=24.2", "packaging>=24.2",
"pyav>=12.0.5", "av>=12.0.5",
"pymunk>=6.6.0", "pymunk>=6.6.0",
"pynput>=1.7.7", "pynput>=1.7.7",
"pyzmq>=26.2.1", "pyzmq>=26.2.1",
@@ -111,10 +111,19 @@ exclude = [
"venv", "venv",
] ]
[tool.ruff.lint] [tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
[tool.bandit]
exclude_dirs = [
"tests",
"benchmarks",
"lerobot/common/datasets/push_dataset_to_hub",
"lerobot/common/datasets/v2/convert_dataset_v1_to_v2",
"lerobot/common/policies/pi0/conversion_scripts",
"lerobot/scripts/push_dataset_to_hub.py",
]
skips = ["B101", "B311", "B404", "B603"]
[tool.typos] [tool.typos]
default.extend-ignore-re = [ default.extend-ignore-re = [