Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots

This commit is contained in:
Simon Alibert
2025-03-10 18:39:48 +01:00
135 changed files with 2177 additions and 514 deletions

View File

@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import packaging.version
V2_MESSAGE = """

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import shutil
from pathlib import Path
@@ -27,6 +28,7 @@ import torch.utils
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
@@ -517,6 +519,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
@@ -562,6 +565,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,

View File

@@ -31,6 +31,7 @@ import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
@@ -325,6 +326,19 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
)
hub_versions = get_repo_versions(repo_id)
if not hub_versions:
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
if target_version in hub_versions:
return f"v{target_version}"

View File

@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import traceback
from pathlib import Path

View File

@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
@@ -57,7 +71,7 @@ def convert_dataset(
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:

View File

@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np