Add push to hub for convert_dataset_v21_to_v30

This commit is contained in:
Remi Cadene
2025-04-21 10:08:25 +02:00
parent 4acf99f622
commit 4375a05a9f

View File

@@ -27,10 +27,11 @@ import pandas as pd
import tqdm import tqdm
from datasets import Dataset from datasets import Dataset
from huggingface_hub import HfApi, snapshot_download from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError
from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_PATH, DEFAULT_DATA_PATH,
@@ -381,20 +382,21 @@ def convert_dataset(
shutil.move(str(root), str(old_root)) shutil.move(str(root), str(old_root))
shutil.move(str(new_root), str(root)) shutil.move(str(new_root), str(root))
# TODO(racdene) hub_api = HfApi()
if False: try:
hub_api = HfApi() hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") except HTTPError as e:
hub_api.delete_files( print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"], pass
repo_id=repo_id, hub_api.delete_files(
revision=branch, delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
repo_type="dataset", repo_id=repo_id,
) revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") LeRobotDataset(repo_id).push_to_hub()
# LeRobotDataset(repo_id).push_to_hub()
if __name__ == "__main__": if __name__ == "__main__":