diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 086411b45..5414c76df 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import logging import shutil from pathlib import Path @@ -27,6 +28,7 @@ import torch.utils from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi, snapshot_download from huggingface_hub.constants import REPOCARD_NAME +from huggingface_hub.errors import RevisionNotFoundError from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats @@ -517,6 +519,7 @@ class LeRobotDataset(torch.utils.data.Dataset): branch: str | None = None, tags: list | None = None, license: str | None = "apache-2.0", + tag_version: bool = True, push_videos: bool = True, private: bool = False, allow_patterns: list[str] | str | None = None, @@ -562,6 +565,11 @@ class LeRobotDataset(torch.utils.data.Dataset): ) card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) + if tag_version: + with contextlib.suppress(RevisionNotFoundError): + hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + def pull_from_repo( self, allow_patterns: list[str] | str | None = None, diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 89adb163c..7e297b350 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -31,6 +31,7 @@ import packaging.version import torch from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi +from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage from torchvision import transforms @@ -325,6 +326,19 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> ) hub_versions = get_repo_versions(repo_id) + if not hub_versions: + raise RevisionNotFoundError( + f"""Your dataset must be tagged with a codebase version. + Assuming _version_ is the codebase_version value in the info.json, you can run this: + ```python + from huggingface_hub import HfApi + + hub_api = HfApi() + hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset") + ``` + """ + ) + if target_version in hub_versions: return f"v{target_version}" diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index 20bda75b8..163a60038 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -57,7 +57,7 @@ def convert_dataset( dataset.meta.info["codebase_version"] = CODEBASE_VERSION write_info(dataset.meta.info, dataset.root) - dataset.push_to_hub(branch=branch, allow_patterns="meta/") + dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/") # delete old stats.json file if (dataset.root / STATS_PATH).is_file: