Compare commits
9 Commits
qgallouede
...
chore/bump
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d420de118 | ||
|
|
bf6f89a5b5 | ||
|
|
8861546ad8 | ||
|
|
9c1a893ee3 | ||
|
|
e81c36cf74 | ||
|
|
ed83cbd4f2 | ||
|
|
2a33b9ad87 | ||
|
|
6e85aa13ec | ||
|
|
af05a1725c |
44
.github/workflows/pr_style_bot.yml
vendored
44
.github/workflows/pr_style_bot.yml
vendored
@@ -5,17 +5,50 @@ on:
|
|||||||
issue_comment:
|
issue_comment:
|
||||||
types: [created]
|
types: [created]
|
||||||
|
|
||||||
permissions:
|
permissions: {}
|
||||||
contents: write
|
|
||||||
pull-requests: write
|
env:
|
||||||
|
PYTHON_VERSION: "3.10"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run-style-bot:
|
check-permissions:
|
||||||
if: >
|
if: >
|
||||||
contains(github.event.comment.body, '@bot /style') &&
|
contains(github.event.comment.body, '@bot /style') &&
|
||||||
github.event.issue.pull_request != null
|
github.event.issue.pull_request != null
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
is_authorized: ${{ steps.check_user_permission.outputs.has_permission }}
|
||||||
|
steps:
|
||||||
|
- name: Check user permission
|
||||||
|
id: check_user_permission
|
||||||
|
uses: actions/github-script@v6
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const comment_user = context.payload.comment.user.login;
|
||||||
|
const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
username: comment_user
|
||||||
|
});
|
||||||
|
|
||||||
|
const authorized =
|
||||||
|
permission.permission === 'admin' ||
|
||||||
|
permission.permission === 'write';
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`User ${comment_user} has permission level: ${permission.permission}, ` +
|
||||||
|
`authorized: ${authorized} (admins & maintainers allowed)`
|
||||||
|
);
|
||||||
|
|
||||||
|
core.setOutput('has_permission', authorized);
|
||||||
|
|
||||||
|
run-style-bot:
|
||||||
|
needs: check-permissions
|
||||||
|
if: needs.check-permissions.outputs.is_authorized == 'true'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: write
|
||||||
steps:
|
steps:
|
||||||
- name: Extract PR details
|
- name: Extract PR details
|
||||||
id: pr_info
|
id: pr_info
|
||||||
@@ -61,6 +94,8 @@ jobs:
|
|||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
- name: Get Ruff Version from pre-commit-config.yaml
|
- name: Get Ruff Version from pre-commit-config.yaml
|
||||||
id: get-ruff-version
|
id: get-ruff-version
|
||||||
@@ -91,6 +126,7 @@ jobs:
|
|||||||
# Configure git with the Actions bot user
|
# Configure git with the Actions bot user
|
||||||
git config user.name "github-actions[bot]"
|
git config user.name "github-actions[bot]"
|
||||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||||
|
git config --local lfs.https://github.com/.locksverify false
|
||||||
|
|
||||||
# Make sure your 'origin' remote is set to the contributor's fork
|
# Make sure your 'origin' remote is set to the contributor's fork
|
||||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"
|
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${HEADREPOFULLNAME}.git"
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
exclude: ^(tests/data)
|
exclude: ^(tests/data)
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.10
|
python: python3.12
|
||||||
repos:
|
repos:
|
||||||
|
##### Style / Misc. #####
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v5.0.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
@@ -14,7 +15,7 @@ repos:
|
|||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- repo: https://github.com/crate-ci/typos
|
- repo: https://github.com/crate-ci/typos
|
||||||
rev: v1.29.10
|
rev: v1.30.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: typos
|
- id: typos
|
||||||
args: [--force-exclude]
|
args: [--force-exclude]
|
||||||
@@ -23,16 +24,24 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.9.6
|
rev: v0.9.9
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|
||||||
|
##### Security #####
|
||||||
- repo: https://github.com/gitleaks/gitleaks
|
- repo: https://github.com/gitleaks/gitleaks
|
||||||
rev: v8.23.3
|
rev: v8.24.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: gitleaks
|
- id: gitleaks
|
||||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||||
rev: v1.3.1
|
rev: v1.4.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: zizmor
|
- id: zizmor
|
||||||
|
- repo: https://github.com/PyCQA/bandit
|
||||||
|
rev: 1.8.3
|
||||||
|
hooks:
|
||||||
|
- id: bandit
|
||||||
|
args: ["-c", "pyproject.toml"]
|
||||||
|
additional_dependencies: ["bandit[toml]"]
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ def main():
|
|||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||||
loss, _ = policy.forward(batch)
|
loss, _ = policy.forward(batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -27,6 +28,7 @@ import torch.utils
|
|||||||
from datasets import concatenate_datasets, load_dataset
|
from datasets import concatenate_datasets, load_dataset
|
||||||
from huggingface_hub import HfApi, snapshot_download
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
from huggingface_hub.constants import REPOCARD_NAME
|
from huggingface_hub.constants import REPOCARD_NAME
|
||||||
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
|
|
||||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||||
@@ -517,6 +519,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
tags: list | None = None,
|
tags: list | None = None,
|
||||||
license: str | None = "apache-2.0",
|
license: str | None = "apache-2.0",
|
||||||
|
tag_version: bool = True,
|
||||||
push_videos: bool = True,
|
push_videos: bool = True,
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
allow_patterns: list[str] | str | None = None,
|
allow_patterns: list[str] | str | None = None,
|
||||||
@@ -562,6 +565,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
|
if tag_version:
|
||||||
|
with contextlib.suppress(RevisionNotFoundError):
|
||||||
|
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||||
|
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||||
|
|
||||||
def pull_from_repo(
|
def pull_from_repo(
|
||||||
self,
|
self,
|
||||||
allow_patterns: list[str] | str | None = None,
|
allow_patterns: list[str] | str | None = None,
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import packaging.version
|
|||||||
import torch
|
import torch
|
||||||
from datasets.table import embed_table_storage
|
from datasets.table import embed_table_storage
|
||||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||||
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
@@ -325,6 +326,19 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
|||||||
)
|
)
|
||||||
hub_versions = get_repo_versions(repo_id)
|
hub_versions = get_repo_versions(repo_id)
|
||||||
|
|
||||||
|
if not hub_versions:
|
||||||
|
raise RevisionNotFoundError(
|
||||||
|
f"""Your dataset must be tagged with a codebase version.
|
||||||
|
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||||
|
```python
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
hub_api = HfApi()
|
||||||
|
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
if target_version in hub_versions:
|
if target_version in hub_versions:
|
||||||
return f"v{target_version}"
|
return f"v{target_version}"
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ def convert_dataset(
|
|||||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||||
write_info(dataset.meta.info, dataset.root)
|
write_info(dataset.meta.info, dataset.root)
|
||||||
|
|
||||||
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
|
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||||
|
|
||||||
# delete old stats.json file
|
# delete old stats.json file
|
||||||
if (dataset.root / STATS_PATH).is_file:
|
if (dataset.root / STATS_PATH).is_file:
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import platform
|
import platform
|
||||||
|
import subprocess
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -165,23 +166,31 @@ def capture_timestamp_utc():
|
|||||||
|
|
||||||
|
|
||||||
def say(text, blocking=False):
|
def say(text, blocking=False):
|
||||||
# Check if mac, linux, or windows.
|
system = platform.system()
|
||||||
if platform.system() == "Darwin":
|
|
||||||
cmd = f'say "{text}"'
|
|
||||||
if not blocking:
|
|
||||||
cmd += " &"
|
|
||||||
elif platform.system() == "Linux":
|
|
||||||
cmd = f'spd-say "{text}"'
|
|
||||||
if blocking:
|
|
||||||
cmd += " --wait"
|
|
||||||
elif platform.system() == "Windows":
|
|
||||||
# TODO(rcadene): Make blocking option work for Windows
|
|
||||||
cmd = (
|
|
||||||
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
|
|
||||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
|
|
||||||
)
|
|
||||||
|
|
||||||
os.system(cmd)
|
if system == "Darwin":
|
||||||
|
cmd = ["say", text]
|
||||||
|
|
||||||
|
elif system == "Linux":
|
||||||
|
cmd = ["spd-say", text]
|
||||||
|
if blocking:
|
||||||
|
cmd.append("--wait")
|
||||||
|
|
||||||
|
elif system == "Windows":
|
||||||
|
cmd = [
|
||||||
|
"PowerShell",
|
||||||
|
"-Command",
|
||||||
|
"Add-Type -AssemblyName System.Speech; "
|
||||||
|
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
|
||||||
|
]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unsupported operating system for text-to-speech.")
|
||||||
|
|
||||||
|
if blocking:
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
else:
|
||||||
|
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
||||||
|
|
||||||
|
|
||||||
def log_say(text, play_sounds, blocking=False):
|
def log_say(text, play_sounds, blocking=False):
|
||||||
|
|||||||
@@ -454,7 +454,7 @@ def _compile_episode_data(
|
|||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def eval(cfg: EvalPipelineConfig):
|
def eval_main(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
@@ -499,4 +499,4 @@ def eval(cfg: EvalPipelineConfig):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_logging()
|
init_logging()
|
||||||
eval()
|
eval_main()
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ def run_server(
|
|||||||
]
|
]
|
||||||
|
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
|
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
# Split into lines and parse each line as JSON
|
# Split into lines and parse each line as JSON
|
||||||
@@ -318,7 +318,9 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
|
|||||||
|
|
||||||
|
|
||||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||||
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
|
response = requests.get(
|
||||||
|
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
|
||||||
|
)
|
||||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||||
dataset_info = response.json()
|
dataset_info = response.json()
|
||||||
dataset_info["repo_id"] = repo_id
|
dataset_info["repo_id"] = repo_id
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ dependencies = [
|
|||||||
"omegaconf>=2.3.0",
|
"omegaconf>=2.3.0",
|
||||||
"opencv-python>=4.9.0",
|
"opencv-python>=4.9.0",
|
||||||
"packaging>=24.2",
|
"packaging>=24.2",
|
||||||
"pyav>=12.0.5",
|
"av>=12.0.5",
|
||||||
"pymunk>=6.6.0",
|
"pymunk>=6.6.0",
|
||||||
"pynput>=1.7.7",
|
"pynput>=1.7.7",
|
||||||
"pyzmq>=26.2.1",
|
"pyzmq>=26.2.1",
|
||||||
@@ -111,10 +111,19 @@ exclude = [
|
|||||||
"venv",
|
"venv",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||||
|
|
||||||
|
[tool.bandit]
|
||||||
|
exclude_dirs = [
|
||||||
|
"tests",
|
||||||
|
"benchmarks",
|
||||||
|
"lerobot/common/datasets/push_dataset_to_hub",
|
||||||
|
"lerobot/common/datasets/v2/convert_dataset_v1_to_v2",
|
||||||
|
"lerobot/common/policies/pi0/conversion_scripts",
|
||||||
|
"lerobot/scripts/push_dataset_to_hub.py",
|
||||||
|
]
|
||||||
|
skips = ["B101", "B311", "B404", "B603"]
|
||||||
|
|
||||||
[tool.typos]
|
[tool.typos]
|
||||||
default.extend-ignore-re = [
|
default.extend-ignore-re = [
|
||||||
|
|||||||
Reference in New Issue
Block a user